diff --git a/src/installer/pkg/sfx/Microsoft.NETCore.App/PackageOverrides.txt b/src/installer/pkg/sfx/Microsoft.NETCore.App/PackageOverrides.txt index f49c727fc22b98..f06bb0137eee58 100644 --- a/src/installer/pkg/sfx/Microsoft.NETCore.App/PackageOverrides.txt +++ b/src/installer/pkg/sfx/Microsoft.NETCore.App/PackageOverrides.txt @@ -187,6 +187,7 @@ System.IO.Pipes|4.3.0 System.IO.Pipes.AccessControl|5.0.0 System.IO.UnmanagedMemoryStream|4.3.0 System.Linq|4.3.0 +System.Linq.AsyncEnumerable|${ProductVersion} System.Linq.Expressions|4.3.0 System.Linq.Parallel|4.3.0 System.Linq.Queryable|4.3.0 diff --git a/src/libraries/NetCoreAppLibrary.props b/src/libraries/NetCoreAppLibrary.props index a3f63978bf92da..8e94beeba018ee 100644 --- a/src/libraries/NetCoreAppLibrary.props +++ b/src/libraries/NetCoreAppLibrary.props @@ -80,6 +80,7 @@ System.IO.Pipelines; System.IO.UnmanagedMemoryStream; System.Linq; + System.Linq.AsyncEnumerable; System.Linq.Expressions; System.Linq.Parallel; System.Linq.Queryable; diff --git a/src/libraries/System.Linq.AsyncEnumerable/Directory.Build.props b/src/libraries/System.Linq.AsyncEnumerable/Directory.Build.props new file mode 100644 index 00000000000000..e8d65546d0c807 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/Directory.Build.props @@ -0,0 +1,6 @@ + + + + Microsoft + + diff --git a/src/libraries/System.Linq.AsyncEnumerable/README.md b/src/libraries/System.Linq.AsyncEnumerable/README.md new file mode 100644 index 00000000000000..c13b9a796b18ce --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/README.md @@ -0,0 +1,18 @@ +# System.Linq.AsyncEnumerable + +Language-Integrated Query (LINQ) is the name for a set of technologies based on the integration of query capabilities directly into the C# language. + +Documentation can be found at https://learn.microsoft.com/dotnet/api/system.linq. + +This library provides an implementation of LINQ APIs for `IAsyncEnumerable`. + +## Contribution Bar + +- [x] [We consider new features, new APIs and performance changes](../../libraries/README.md#primary-bar) +- [x] [We consider PRs that target this library for new source code analyzers](../../libraries/README.md#secondary-bars) + +See the [Help Wanted](https://github.com/dotnet/runtime/issues?q=is%3Aissue+is%3Aopen+label%3Aarea-System.Linq+label%3A%22help+wanted%22+) issues. + +## Deployment + +System.Linq.AsyncEnumerable is shipped as part of the .NET shared framework and as a NuGet package. diff --git a/src/libraries/System.Linq.AsyncEnumerable/System.Linq.AsyncEnumerable.sln b/src/libraries/System.Linq.AsyncEnumerable/System.Linq.AsyncEnumerable.sln new file mode 100644 index 00000000000000..ca0fda7c6a912c --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/System.Linq.AsyncEnumerable.sln @@ -0,0 +1,149 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.13.35602.250 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TestUtilities", "..\Common\tests\TestUtilities\TestUtilities.csproj", "{AF1B1B01-A4EC-45F4-AE51-CC1FA7892181}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "System.Collections", "..\System.Collections\ref\System.Collections.csproj", "{3A8560D8-0E79-4BDE-802A-C96C7FE98258}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "System.Linq.AsyncEnumerable", "ref\System.Linq.AsyncEnumerable.csproj", "{7E4C1F09-B4F2-470E-9E7B-2C386E93D657}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "System.Linq.AsyncEnumerable", "src\System.Linq.AsyncEnumerable.csproj", "{14B966BB-CE23-4432-ADBB-89974389AC1D}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "System.Linq.AsyncEnumerable.Tests", "tests\System.Linq.AsyncEnumerable.Tests.csproj", "{80A4051B-4A36-4A8B-BA43-A5AB8AA959F3}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ComInterfaceGenerator", "..\System.Runtime.InteropServices\gen\ComInterfaceGenerator\ComInterfaceGenerator.csproj", "{9A13A12F-C924-43AF-94AF-6F1B33582D27}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "DownlevelLibraryImportGenerator", "..\System.Runtime.InteropServices\gen\DownlevelLibraryImportGenerator\DownlevelLibraryImportGenerator.csproj", "{C026F4C2-949D-4F73-845B-0D78993A83B0}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LibraryImportGenerator", "..\System.Runtime.InteropServices\gen\LibraryImportGenerator\LibraryImportGenerator.csproj", "{4BEC631E-B5FD-453F-82A0-C95C461798EA}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Interop.SourceGeneration", "..\System.Runtime.InteropServices\gen\Microsoft.Interop.SourceGeneration\Microsoft.Interop.SourceGeneration.csproj", "{C8F0459C-15D5-4624-8CE4-E93ADF96A28C}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "System.Runtime", "..\System.Runtime\ref\System.Runtime.csproj", "{D3160C37-FC48-4907-8F4A-F584ED12B275}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ILLink.CodeFixProvider", "..\..\tools\illink\src\ILLink.CodeFix\ILLink.CodeFixProvider.csproj", "{E0CA3ED5-EE6C-4F7C-BCE7-EFB1D64A9CD1}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ILLink.RoslynAnalyzer", "..\..\tools\illink\src\ILLink.RoslynAnalyzer\ILLink.RoslynAnalyzer.csproj", "{3EFB74E7-616A-48C1-B43B-3F89AA5013E6}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ILLink.Tasks", "..\..\tools\illink\src\ILLink.Tasks\ILLink.Tasks.csproj", "{28ABC524-ACEE-4183-A64A-49E3DC830595}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Mono.Linker", "..\..\tools\illink\src\linker\Mono.Linker.csproj", "{721DB3D9-8221-424E-BE29-084CDD20D26E}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Mono.Linker", "..\..\tools\illink\src\linker\ref\Mono.Linker.csproj", "{E19B8772-2DBD-4274-8190-F3CC0242A1C0}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tests", "tests", "{E291F4BF-7B8B-45AD-88F5-FB8B8380C126}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "ref", "{18C4E23D-AB0F-45E5-A6A1-A741F6462E85}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{F8F69023-9ACD-4979-A710-39D16377AEEE}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "gen", "gen", "{34793393-0347-438D-A832-2476F33C1BE3}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "gen", "gen", "{3EC69C1A-F3A3-4057-8DB0-D2ECD915AD5A}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{F1EFB29E-59BF-4165-953D-DC49A3F289DB}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "ref", "ref", "{9B6443FD-0249-4934-B885-D0A503F87DB4}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools", "tools", "{0ADC596A-5B2E-4E5F-B5B5-DEB65A6C7E9D}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {AF1B1B01-A4EC-45F4-AE51-CC1FA7892181}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {AF1B1B01-A4EC-45F4-AE51-CC1FA7892181}.Debug|Any CPU.Build.0 = Debug|Any CPU + {AF1B1B01-A4EC-45F4-AE51-CC1FA7892181}.Release|Any CPU.ActiveCfg = Release|Any CPU + {AF1B1B01-A4EC-45F4-AE51-CC1FA7892181}.Release|Any CPU.Build.0 = Release|Any CPU + {3A8560D8-0E79-4BDE-802A-C96C7FE98258}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3A8560D8-0E79-4BDE-802A-C96C7FE98258}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3A8560D8-0E79-4BDE-802A-C96C7FE98258}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3A8560D8-0E79-4BDE-802A-C96C7FE98258}.Release|Any CPU.Build.0 = Release|Any CPU + {7E4C1F09-B4F2-470E-9E7B-2C386E93D657}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {7E4C1F09-B4F2-470E-9E7B-2C386E93D657}.Debug|Any CPU.Build.0 = Debug|Any CPU + {7E4C1F09-B4F2-470E-9E7B-2C386E93D657}.Release|Any CPU.ActiveCfg = Release|Any CPU + {7E4C1F09-B4F2-470E-9E7B-2C386E93D657}.Release|Any CPU.Build.0 = Release|Any CPU + {14B966BB-CE23-4432-ADBB-89974389AC1D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {14B966BB-CE23-4432-ADBB-89974389AC1D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {14B966BB-CE23-4432-ADBB-89974389AC1D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {14B966BB-CE23-4432-ADBB-89974389AC1D}.Release|Any CPU.Build.0 = Release|Any CPU + {80A4051B-4A36-4A8B-BA43-A5AB8AA959F3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {80A4051B-4A36-4A8B-BA43-A5AB8AA959F3}.Debug|Any CPU.Build.0 = Debug|Any CPU + {80A4051B-4A36-4A8B-BA43-A5AB8AA959F3}.Release|Any CPU.ActiveCfg = Release|Any CPU + {80A4051B-4A36-4A8B-BA43-A5AB8AA959F3}.Release|Any CPU.Build.0 = Release|Any CPU + {9A13A12F-C924-43AF-94AF-6F1B33582D27}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {9A13A12F-C924-43AF-94AF-6F1B33582D27}.Debug|Any CPU.Build.0 = Debug|Any CPU + {9A13A12F-C924-43AF-94AF-6F1B33582D27}.Release|Any CPU.ActiveCfg = Release|Any CPU + {9A13A12F-C924-43AF-94AF-6F1B33582D27}.Release|Any CPU.Build.0 = Release|Any CPU + {C026F4C2-949D-4F73-845B-0D78993A83B0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {C026F4C2-949D-4F73-845B-0D78993A83B0}.Debug|Any CPU.Build.0 = Debug|Any CPU + {C026F4C2-949D-4F73-845B-0D78993A83B0}.Release|Any CPU.ActiveCfg = Release|Any CPU + {C026F4C2-949D-4F73-845B-0D78993A83B0}.Release|Any CPU.Build.0 = Release|Any CPU + {4BEC631E-B5FD-453F-82A0-C95C461798EA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {4BEC631E-B5FD-453F-82A0-C95C461798EA}.Debug|Any CPU.Build.0 = Debug|Any CPU + {4BEC631E-B5FD-453F-82A0-C95C461798EA}.Release|Any CPU.ActiveCfg = Release|Any CPU + {4BEC631E-B5FD-453F-82A0-C95C461798EA}.Release|Any CPU.Build.0 = Release|Any CPU + {C8F0459C-15D5-4624-8CE4-E93ADF96A28C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {C8F0459C-15D5-4624-8CE4-E93ADF96A28C}.Debug|Any CPU.Build.0 = Debug|Any CPU + {C8F0459C-15D5-4624-8CE4-E93ADF96A28C}.Release|Any CPU.ActiveCfg = Release|Any CPU + {C8F0459C-15D5-4624-8CE4-E93ADF96A28C}.Release|Any CPU.Build.0 = Release|Any CPU + {D3160C37-FC48-4907-8F4A-F584ED12B275}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D3160C37-FC48-4907-8F4A-F584ED12B275}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D3160C37-FC48-4907-8F4A-F584ED12B275}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D3160C37-FC48-4907-8F4A-F584ED12B275}.Release|Any CPU.Build.0 = Release|Any CPU + {E0CA3ED5-EE6C-4F7C-BCE7-EFB1D64A9CD1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E0CA3ED5-EE6C-4F7C-BCE7-EFB1D64A9CD1}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E0CA3ED5-EE6C-4F7C-BCE7-EFB1D64A9CD1}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E0CA3ED5-EE6C-4F7C-BCE7-EFB1D64A9CD1}.Release|Any CPU.Build.0 = Release|Any CPU + {3EFB74E7-616A-48C1-B43B-3F89AA5013E6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3EFB74E7-616A-48C1-B43B-3F89AA5013E6}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3EFB74E7-616A-48C1-B43B-3F89AA5013E6}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3EFB74E7-616A-48C1-B43B-3F89AA5013E6}.Release|Any CPU.Build.0 = Release|Any CPU + {28ABC524-ACEE-4183-A64A-49E3DC830595}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {28ABC524-ACEE-4183-A64A-49E3DC830595}.Debug|Any CPU.Build.0 = Debug|Any CPU + {28ABC524-ACEE-4183-A64A-49E3DC830595}.Release|Any CPU.ActiveCfg = Release|Any CPU + {28ABC524-ACEE-4183-A64A-49E3DC830595}.Release|Any CPU.Build.0 = Release|Any CPU + {721DB3D9-8221-424E-BE29-084CDD20D26E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {721DB3D9-8221-424E-BE29-084CDD20D26E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {721DB3D9-8221-424E-BE29-084CDD20D26E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {721DB3D9-8221-424E-BE29-084CDD20D26E}.Release|Any CPU.Build.0 = Release|Any CPU + {E19B8772-2DBD-4274-8190-F3CC0242A1C0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E19B8772-2DBD-4274-8190-F3CC0242A1C0}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E19B8772-2DBD-4274-8190-F3CC0242A1C0}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E19B8772-2DBD-4274-8190-F3CC0242A1C0}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(NestedProjects) = preSolution + {AF1B1B01-A4EC-45F4-AE51-CC1FA7892181} = {E291F4BF-7B8B-45AD-88F5-FB8B8380C126} + {3A8560D8-0E79-4BDE-802A-C96C7FE98258} = {18C4E23D-AB0F-45E5-A6A1-A741F6462E85} + {7E4C1F09-B4F2-470E-9E7B-2C386E93D657} = {18C4E23D-AB0F-45E5-A6A1-A741F6462E85} + {14B966BB-CE23-4432-ADBB-89974389AC1D} = {F8F69023-9ACD-4979-A710-39D16377AEEE} + {80A4051B-4A36-4A8B-BA43-A5AB8AA959F3} = {E291F4BF-7B8B-45AD-88F5-FB8B8380C126} + {9A13A12F-C924-43AF-94AF-6F1B33582D27} = {34793393-0347-438D-A832-2476F33C1BE3} + {C026F4C2-949D-4F73-845B-0D78993A83B0} = {34793393-0347-438D-A832-2476F33C1BE3} + {4BEC631E-B5FD-453F-82A0-C95C461798EA} = {34793393-0347-438D-A832-2476F33C1BE3} + {C8F0459C-15D5-4624-8CE4-E93ADF96A28C} = {34793393-0347-438D-A832-2476F33C1BE3} + {D3160C37-FC48-4907-8F4A-F584ED12B275} = {18C4E23D-AB0F-45E5-A6A1-A741F6462E85} + {E0CA3ED5-EE6C-4F7C-BCE7-EFB1D64A9CD1} = {3EC69C1A-F3A3-4057-8DB0-D2ECD915AD5A} + {3EFB74E7-616A-48C1-B43B-3F89AA5013E6} = {3EC69C1A-F3A3-4057-8DB0-D2ECD915AD5A} + {28ABC524-ACEE-4183-A64A-49E3DC830595} = {F1EFB29E-59BF-4165-953D-DC49A3F289DB} + {721DB3D9-8221-424E-BE29-084CDD20D26E} = {F1EFB29E-59BF-4165-953D-DC49A3F289DB} + {E19B8772-2DBD-4274-8190-F3CC0242A1C0} = {9B6443FD-0249-4934-B885-D0A503F87DB4} + {3EC69C1A-F3A3-4057-8DB0-D2ECD915AD5A} = {0ADC596A-5B2E-4E5F-B5B5-DEB65A6C7E9D} + {F1EFB29E-59BF-4165-953D-DC49A3F289DB} = {0ADC596A-5B2E-4E5F-B5B5-DEB65A6C7E9D} + {9B6443FD-0249-4934-B885-D0A503F87DB4} = {0ADC596A-5B2E-4E5F-B5B5-DEB65A6C7E9D} + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {A4970D79-BF1C-4343-9070-B409DBB69F93} + EndGlobalSection + GlobalSection(SharedMSBuildProjectFiles) = preSolution + ..\..\tools\illink\src\ILLink.Shared\ILLink.Shared.projitems*{3efb74e7-616a-48c1-b43b-3f89aa5013e6}*SharedItemsImports = 5 + ..\..\tools\illink\src\ILLink.Shared\ILLink.Shared.projitems*{721db3d9-8221-424e-be29-084cdd20d26e}*SharedItemsImports = 5 + EndGlobalSection +EndGlobal diff --git a/src/libraries/System.Linq.AsyncEnumerable/ref/System.Linq.AsyncEnumerable.cs b/src/libraries/System.Linq.AsyncEnumerable/ref/System.Linq.AsyncEnumerable.cs new file mode 100644 index 00000000000000..6153d0bd56a07c --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/ref/System.Linq.AsyncEnumerable.cs @@ -0,0 +1,202 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// ------------------------------------------------------------------------------ +// Changes to this file must follow the https://aka.ms/api-review process. +// ------------------------------------------------------------------------------ + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + public static System.Threading.Tasks.ValueTask AggregateAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> func, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask AggregateAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func func, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask AggregateAsync(this System.Collections.Generic.IAsyncEnumerable source, TAccumulate seed, System.Func> func, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask AggregateAsync(this System.Collections.Generic.IAsyncEnumerable source, TAccumulate seed, System.Func func, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask AggregateAsync(this System.Collections.Generic.IAsyncEnumerable source, TAccumulate seed, System.Func> func, System.Func> resultSelector, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask AggregateAsync(this System.Collections.Generic.IAsyncEnumerable source, TAccumulate seed, System.Func func, System.Func resultSelector, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable> AggregateBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func> keySelector, System.Func> seedSelector, System.Func> func, System.Collections.Generic.IEqualityComparer? keyComparer = null) where TKey : notnull { throw null; } + public static System.Collections.Generic.IAsyncEnumerable> AggregateBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func> keySelector, TAccumulate seed, System.Func> func, System.Collections.Generic.IEqualityComparer? keyComparer = null) where TKey : notnull { throw null; } + public static System.Collections.Generic.IAsyncEnumerable> AggregateBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func keySelector, System.Func seedSelector, System.Func func, System.Collections.Generic.IEqualityComparer? keyComparer = null) where TKey : notnull { throw null; } + public static System.Collections.Generic.IAsyncEnumerable> AggregateBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func keySelector, TAccumulate seed, System.Func func, System.Collections.Generic.IEqualityComparer? keyComparer = null) where TKey : notnull { throw null; } + public static System.Threading.Tasks.ValueTask AllAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask AllAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask AnyAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask AnyAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask AnyAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Append(this System.Collections.Generic.IAsyncEnumerable source, TSource element) { throw null; } + public static System.Threading.Tasks.ValueTask AverageAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask AverageAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask AverageAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask AverageAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask AverageAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask AverageAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask AverageAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask AverageAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask AverageAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask AverageAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Cast(this System.Collections.Generic.IAsyncEnumerable source) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Chunk(this System.Collections.Generic.IAsyncEnumerable source, int size) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Concat(this System.Collections.Generic.IAsyncEnumerable first, System.Collections.Generic.IAsyncEnumerable second) { throw null; } + public static System.Threading.Tasks.ValueTask ContainsAsync(this System.Collections.Generic.IAsyncEnumerable source, TSource value, System.Collections.Generic.IEqualityComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask CountAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask CountAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask CountAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable> CountBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func> keySelector, System.Collections.Generic.IEqualityComparer? keyComparer = null) where TKey : notnull { throw null; } + public static System.Collections.Generic.IAsyncEnumerable> CountBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func keySelector, System.Collections.Generic.IEqualityComparer? keyComparer = null) where TKey : notnull { throw null; } + public static System.Collections.Generic.IAsyncEnumerable DefaultIfEmpty(this System.Collections.Generic.IAsyncEnumerable source) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable DefaultIfEmpty(this System.Collections.Generic.IAsyncEnumerable source, TSource defaultValue) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable DistinctBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func> keySelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable DistinctBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func keySelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Distinct(this System.Collections.Generic.IAsyncEnumerable source, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Threading.Tasks.ValueTask ElementAtAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Index index, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask ElementAtAsync(this System.Collections.Generic.IAsyncEnumerable source, int index, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask ElementAtOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Index index, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask ElementAtOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, int index, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Empty() { throw null; } + public static System.Collections.Generic.IAsyncEnumerable ExceptBy(this System.Collections.Generic.IAsyncEnumerable first, System.Collections.Generic.IAsyncEnumerable second, System.Func> keySelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable ExceptBy(this System.Collections.Generic.IAsyncEnumerable first, System.Collections.Generic.IAsyncEnumerable second, System.Func keySelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Except(this System.Collections.Generic.IAsyncEnumerable first, System.Collections.Generic.IAsyncEnumerable second, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Threading.Tasks.ValueTask FirstAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask FirstAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask FirstAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask FirstOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask FirstOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate, TSource defaultValue, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask FirstOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask FirstOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate, TSource defaultValue, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask FirstOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask FirstOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, TSource defaultValue, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable> GroupBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func> keySelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable> GroupBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func keySelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable> GroupBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func> keySelector, System.Func> elementSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable GroupBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func> keySelector, System.Func, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask> resultSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable> GroupBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func keySelector, System.Func elementSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable GroupBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func keySelector, System.Func, TResult> resultSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable GroupBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func> keySelector, System.Func> elementSelector, System.Func, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask> resultSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable GroupBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func keySelector, System.Func elementSelector, System.Func, TResult> resultSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable GroupJoin(this System.Collections.Generic.IAsyncEnumerable outer, System.Collections.Generic.IAsyncEnumerable inner, System.Func> outerKeySelector, System.Func> innerKeySelector, System.Func, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask> resultSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable GroupJoin(this System.Collections.Generic.IAsyncEnumerable outer, System.Collections.Generic.IAsyncEnumerable inner, System.Func outerKeySelector, System.Func innerKeySelector, System.Func, TResult> resultSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable<(int Index, TSource Item)> Index(this System.Collections.Generic.IAsyncEnumerable source) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable IntersectBy(this System.Collections.Generic.IAsyncEnumerable first, System.Collections.Generic.IAsyncEnumerable second, System.Func> keySelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable IntersectBy(this System.Collections.Generic.IAsyncEnumerable first, System.Collections.Generic.IAsyncEnumerable second, System.Func keySelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Intersect(this System.Collections.Generic.IAsyncEnumerable first, System.Collections.Generic.IAsyncEnumerable second, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Join(this System.Collections.Generic.IAsyncEnumerable outer, System.Collections.Generic.IAsyncEnumerable inner, System.Func> outerKeySelector, System.Func> innerKeySelector, System.Func> resultSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Join(this System.Collections.Generic.IAsyncEnumerable outer, System.Collections.Generic.IAsyncEnumerable inner, System.Func outerKeySelector, System.Func innerKeySelector, System.Func resultSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Threading.Tasks.ValueTask LastAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask LastAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask LastAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask LastOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask LastOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate, TSource defaultValue, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask LastOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask LastOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate, TSource defaultValue, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask LastOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask LastOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, TSource defaultValue, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable LeftJoin(this System.Collections.Generic.IAsyncEnumerable outer, System.Collections.Generic.IAsyncEnumerable inner, System.Func> outerKeySelector, System.Func> innerKeySelector, System.Func> resultSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable LeftJoin(this System.Collections.Generic.IAsyncEnumerable outer, System.Collections.Generic.IAsyncEnumerable inner, System.Func outerKeySelector, System.Func innerKeySelector, System.Func resultSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Threading.Tasks.ValueTask LongCountAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask LongCountAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask LongCountAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask MaxAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Collections.Generic.IComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask MaxByAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> keySelector, System.Collections.Generic.IComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask MaxByAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func keySelector, System.Collections.Generic.IComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask MinAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Collections.Generic.IComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask MinByAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> keySelector, System.Collections.Generic.IComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask MinByAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func keySelector, System.Collections.Generic.IComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable OfType(this System.Collections.Generic.IAsyncEnumerable source) { throw null; } + public static System.Linq.IOrderedAsyncEnumerable OrderByDescending(this System.Collections.Generic.IAsyncEnumerable source, System.Func> keySelector, System.Collections.Generic.IComparer? comparer = null) { throw null; } + public static System.Linq.IOrderedAsyncEnumerable OrderByDescending(this System.Collections.Generic.IAsyncEnumerable source, System.Func keySelector, System.Collections.Generic.IComparer? comparer = null) { throw null; } + public static System.Linq.IOrderedAsyncEnumerable OrderBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func> keySelector, System.Collections.Generic.IComparer? comparer = null) { throw null; } + public static System.Linq.IOrderedAsyncEnumerable OrderBy(this System.Collections.Generic.IAsyncEnumerable source, System.Func keySelector, System.Collections.Generic.IComparer? comparer = null) { throw null; } + public static System.Linq.IOrderedAsyncEnumerable OrderDescending(this System.Collections.Generic.IAsyncEnumerable source, System.Collections.Generic.IComparer? comparer = null) { throw null; } + public static System.Linq.IOrderedAsyncEnumerable Order(this System.Collections.Generic.IAsyncEnumerable source, System.Collections.Generic.IComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Prepend(this System.Collections.Generic.IAsyncEnumerable source, TSource element) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Range(int start, int count) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Repeat(TResult element, int count) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Reverse(this System.Collections.Generic.IAsyncEnumerable source) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable RightJoin(this System.Collections.Generic.IAsyncEnumerable outer, System.Collections.Generic.IAsyncEnumerable inner, System.Func> outerKeySelector, System.Func> innerKeySelector, System.Func> resultSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable RightJoin(this System.Collections.Generic.IAsyncEnumerable outer, System.Collections.Generic.IAsyncEnumerable inner, System.Func outerKeySelector, System.Func innerKeySelector, System.Func resultSelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable SelectMany(this System.Collections.Generic.IAsyncEnumerable source, System.Func> selector) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable SelectMany(this System.Collections.Generic.IAsyncEnumerable source, System.Func> selector) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable SelectMany(this System.Collections.Generic.IAsyncEnumerable source, System.Func> selector) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable SelectMany(this System.Collections.Generic.IAsyncEnumerable source, System.Func> selector) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable SelectMany(this System.Collections.Generic.IAsyncEnumerable source, System.Func>> selector) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable SelectMany(this System.Collections.Generic.IAsyncEnumerable source, System.Func>> selector) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable SelectMany(this System.Collections.Generic.IAsyncEnumerable source, System.Func> collectionSelector, System.Func> resultSelector) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable SelectMany(this System.Collections.Generic.IAsyncEnumerable source, System.Func> collectionSelector, System.Func resultSelector) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable SelectMany(this System.Collections.Generic.IAsyncEnumerable source, System.Func> collectionSelector, System.Func resultSelector) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable SelectMany(this System.Collections.Generic.IAsyncEnumerable source, System.Func> collectionSelector, System.Func> resultSelector) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable SelectMany(this System.Collections.Generic.IAsyncEnumerable source, System.Func> collectionSelector, System.Func resultSelector) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable SelectMany(this System.Collections.Generic.IAsyncEnumerable source, System.Func>> collectionSelector, System.Func> resultSelector) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable SelectMany(this System.Collections.Generic.IAsyncEnumerable source, System.Func>> collectionSelector, System.Func> resultSelector) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Select(this System.Collections.Generic.IAsyncEnumerable source, System.Func> selector) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Select(this System.Collections.Generic.IAsyncEnumerable source, System.Func selector) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Select(this System.Collections.Generic.IAsyncEnumerable source, System.Func> selector) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Select(this System.Collections.Generic.IAsyncEnumerable source, System.Func selector) { throw null; } + public static System.Threading.Tasks.ValueTask SequenceEqualAsync(this System.Collections.Generic.IAsyncEnumerable first, System.Collections.Generic.IAsyncEnumerable second, System.Collections.Generic.IEqualityComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask SingleAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask SingleAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask SingleAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask SingleOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask SingleOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate, TSource defaultValue, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask SingleOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask SingleOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate, TSource defaultValue, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask SingleOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask SingleOrDefaultAsync(this System.Collections.Generic.IAsyncEnumerable source, TSource defaultValue, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable SkipLast(this System.Collections.Generic.IAsyncEnumerable source, int count) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable SkipWhile(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable SkipWhile(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable SkipWhile(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable SkipWhile(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Skip(this System.Collections.Generic.IAsyncEnumerable source, int count) { throw null; } + public static System.Threading.Tasks.ValueTask SumAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask SumAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask SumAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask SumAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask SumAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask SumAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask SumAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask SumAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask SumAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask SumAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable TakeLast(this System.Collections.Generic.IAsyncEnumerable source, int count) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable TakeWhile(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable TakeWhile(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable TakeWhile(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable TakeWhile(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Take(this System.Collections.Generic.IAsyncEnumerable source, int count) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Take(this System.Collections.Generic.IAsyncEnumerable source, System.Range range) { throw null; } + public static System.Linq.IOrderedAsyncEnumerable ThenByDescending(this System.Linq.IOrderedAsyncEnumerable source, System.Func> keySelector, System.Collections.Generic.IComparer? comparer = null) { throw null; } + public static System.Linq.IOrderedAsyncEnumerable ThenByDescending(this System.Linq.IOrderedAsyncEnumerable source, System.Func keySelector, System.Collections.Generic.IComparer? comparer = null) { throw null; } + public static System.Linq.IOrderedAsyncEnumerable ThenBy(this System.Linq.IOrderedAsyncEnumerable source, System.Func> keySelector, System.Collections.Generic.IComparer? comparer = null) { throw null; } + public static System.Linq.IOrderedAsyncEnumerable ThenBy(this System.Linq.IOrderedAsyncEnumerable source, System.Func keySelector, System.Collections.Generic.IComparer? comparer = null) { throw null; } + public static System.Threading.Tasks.ValueTask ToArrayAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable ToAsyncEnumerable(this System.Collections.Generic.IEnumerable source) { throw null; } + public static System.Threading.Tasks.ValueTask> ToDictionaryAsync(this System.Collections.Generic.IAsyncEnumerable> source, System.Collections.Generic.IEqualityComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) where TKey : notnull { throw null; } + public static System.Threading.Tasks.ValueTask> ToDictionaryAsync(this System.Collections.Generic.IAsyncEnumerable<(TKey Key, TValue Value)> source, System.Collections.Generic.IEqualityComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) where TKey : notnull { throw null; } + public static System.Threading.Tasks.ValueTask> ToDictionaryAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> keySelector, System.Collections.Generic.IEqualityComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) where TKey : notnull { throw null; } + public static System.Threading.Tasks.ValueTask> ToDictionaryAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func keySelector, System.Collections.Generic.IEqualityComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) where TKey : notnull { throw null; } + public static System.Threading.Tasks.ValueTask> ToDictionaryAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> keySelector, System.Func> elementSelector, System.Collections.Generic.IEqualityComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) where TKey : notnull { throw null; } + public static System.Threading.Tasks.ValueTask> ToDictionaryAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func keySelector, System.Func elementSelector, System.Collections.Generic.IEqualityComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) where TKey : notnull { throw null; } + public static System.Threading.Tasks.ValueTask> ToHashSetAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Collections.Generic.IEqualityComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask> ToListAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask> ToLookupAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> keySelector, System.Collections.Generic.IEqualityComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask> ToLookupAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func keySelector, System.Collections.Generic.IEqualityComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask> ToLookupAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func> keySelector, System.Func> elementSelector, System.Collections.Generic.IEqualityComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Threading.Tasks.ValueTask> ToLookupAsync(this System.Collections.Generic.IAsyncEnumerable source, System.Func keySelector, System.Func elementSelector, System.Collections.Generic.IEqualityComparer? comparer = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable UnionBy(this System.Collections.Generic.IAsyncEnumerable first, System.Collections.Generic.IAsyncEnumerable second, System.Func> keySelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable UnionBy(this System.Collections.Generic.IAsyncEnumerable first, System.Collections.Generic.IAsyncEnumerable second, System.Func keySelector, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Union(this System.Collections.Generic.IAsyncEnumerable first, System.Collections.Generic.IAsyncEnumerable second, System.Collections.Generic.IEqualityComparer? comparer = null) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Where(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Where(this System.Collections.Generic.IAsyncEnumerable source, System.Func predicate) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Where(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Where(this System.Collections.Generic.IAsyncEnumerable source, System.Func> predicate) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable<(TFirst First, TSecond Second)> Zip(this System.Collections.Generic.IAsyncEnumerable first, System.Collections.Generic.IAsyncEnumerable second) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable<(TFirst First, TSecond Second, TThird Third)> Zip(this System.Collections.Generic.IAsyncEnumerable first, System.Collections.Generic.IAsyncEnumerable second, System.Collections.Generic.IAsyncEnumerable third) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Zip(this System.Collections.Generic.IAsyncEnumerable first, System.Collections.Generic.IAsyncEnumerable second, System.Func> resultSelector) { throw null; } + public static System.Collections.Generic.IAsyncEnumerable Zip(this System.Collections.Generic.IAsyncEnumerable first, System.Collections.Generic.IAsyncEnumerable second, System.Func resultSelector) { throw null; } + } + public partial interface IOrderedAsyncEnumerable : System.Collections.Generic.IAsyncEnumerable + { + System.Linq.IOrderedAsyncEnumerable CreateOrderedAsyncEnumerable(System.Func> keySelector, System.Collections.Generic.IComparer? comparer, bool descending); + System.Linq.IOrderedAsyncEnumerable CreateOrderedAsyncEnumerable(System.Func keySelector, System.Collections.Generic.IComparer? comparer, bool descending); + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/ref/System.Linq.AsyncEnumerable.csproj b/src/libraries/System.Linq.AsyncEnumerable/ref/System.Linq.AsyncEnumerable.csproj new file mode 100644 index 00000000000000..b8f3303989b913 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/ref/System.Linq.AsyncEnumerable.csproj @@ -0,0 +1,28 @@ + + + $(NetCoreAppCurrent);$(NetCoreAppPrevious);$(NetCoreAppMinimum);netstandard2.0;$(NetFrameworkMinimum) + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/PACKAGE.md b/src/libraries/System.Linq.AsyncEnumerable/src/PACKAGE.md new file mode 100644 index 00000000000000..a9e3430e28bc37 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/PACKAGE.md @@ -0,0 +1,40 @@ +## About + +The `System.Linq.AsyncEnumerable` library provides support for Language-Integrated Query (LINQ) over `IAsyncEnumerable` sequences. + +## Key Features + +* Extension methods for performing operations on `IAsyncEnumerable` sequences. + +## How to Use + +```C# +using System; +using System.IO; +using System.Linq; + +static IAsyncEnumerable DeserializeAndFilterData(Stream stream) +{ + IAsyncEnumerable cities = JsonSerializer.DeserializeAsyncEnumerable(stream); + + return from city in cities + where city.Population > 10_000 + orderby city.Name + select city; +} +``` + +## Main Types + +The main type provided by this library is: + +* `System.Linq.AsyncEnumerable` + +## Additional Documentation + +* [Overview](https://learn.microsoft.com/dotnet/csharp/linq/) +* [API documentation](https://learn.microsoft.com/dotnet/api/system.linq) + +## Feedback & Contributing + +`System.Linq.AsyncEnumerable` is released as open source under the [MIT license](https://licenses.nuget.org/MIT). Bug reports and contributions are welcome at [the GitHub repository](https://github.com/dotnet/runtime). diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/Resources/Strings.resx b/src/libraries/System.Linq.AsyncEnumerable/src/Resources/Strings.resx new file mode 100644 index 00000000000000..7c706b7926a27c --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/Resources/Strings.resx @@ -0,0 +1,75 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + Enumeration yielded no results + + + Sequence contains more than one element + + + Sequence contains more than one matching element + + + Sequence contains no elements + + + Sequence contains no matching element + + \ No newline at end of file diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System.Linq.AsyncEnumerable.csproj b/src/libraries/System.Linq.AsyncEnumerable/src/System.Linq.AsyncEnumerable.csproj new file mode 100644 index 00000000000000..65d7a694bb6a54 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System.Linq.AsyncEnumerable.csproj @@ -0,0 +1,103 @@ + + + + $(NetCoreAppCurrent);$(NetCoreAppPrevious);$(NetCoreAppMinimum);netstandard2.0;$(NetFrameworkMinimum) + true + $(NoWarn);CS1998 + + + true + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/AggregateAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/AggregateAsync.cs new file mode 100644 index 00000000000000..4c609a4b9c5fd7 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/AggregateAsync.cs @@ -0,0 +1,279 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Applies an accumulator function over a sequence. + /// The type of the elements of source. + /// An to aggregate over. + /// An accumulator function to be invoked on each element. + /// The to monitor for cancellation requests. The default is . + /// The final accumulator value. + /// is . + /// is . + /// contains no elements. + public static ValueTask AggregateAsync( + this IAsyncEnumerable source, + Func func, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(func); + + return Impl(source, func, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func func, + CancellationToken cancellationToken) + { + TSource result; + + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + ThrowHelper.ThrowNoElementsException(); + } + + result = e.Current; + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + result = func(result, e.Current); + } + + return result; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Applies an accumulator function over a sequence. + /// The type of the elements of source. + /// An to aggregate over. + /// An accumulator function to be invoked on each element. + /// The to monitor for cancellation requests. The default is . + /// The final accumulator value. + /// is . + /// is . + /// contains no elements. + public static ValueTask AggregateAsync( + this IAsyncEnumerable source, + Func> func, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(func); + + return Impl(source, func, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func> func, + CancellationToken cancellationToken) + { + TSource result; + + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + ThrowHelper.ThrowNoElementsException(); + } + + result = e.Current; + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + result = await func(result, e.Current, cancellationToken).ConfigureAwait(false); + } + + return result; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Applies an accumulator function over a sequence. The specified seed value is used as the initial accumulator value. + /// The type of the elements of source. + /// The type of the accumulator value. + /// An to aggregate over. + /// The initial accumulator value. + /// An accumulator function to be invoked on each element. + /// The to monitor for cancellation requests. The default is . + /// The final accumulator value. + /// is . + /// is . + public static ValueTask AggregateAsync( + this IAsyncEnumerable source, + TAccumulate seed, + Func func, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(func); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false), seed, func); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source, + TAccumulate seed, + Func func) + { + TAccumulate result = seed; + + await foreach (TSource element in source) + { + result = func(result, element); + } + + return result; + } + } + + /// Applies an accumulator function over a sequence. The specified seed value is used as the initial accumulator value. + /// The type of the elements of source. + /// The type of the accumulator value. + /// An to aggregate over. + /// The initial accumulator value. + /// An accumulator function to be invoked on each element. + /// The to monitor for cancellation requests. The default is . + /// The final accumulator value. + /// is . + /// is . + public static ValueTask AggregateAsync( + this IAsyncEnumerable source, TAccumulate seed, + Func> func, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(func); + + return Impl(source, seed, func, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, TAccumulate seed, + Func> func, + CancellationToken cancellationToken = default) + { + TAccumulate result = seed; + + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + result = await func(result, element, cancellationToken).ConfigureAwait(false); + } + + return result; + } + } + + /// + /// Applies an accumulator function over a sequence. The specified seed value is + /// used as the initial accumulator value, and the specified function is used to + /// select the result value. + /// + /// The type of the elements of source. + /// The type of the accumulator value. + /// The type of the resulting value. + /// An to aggregate over. + /// The initial accumulator value. + /// An accumulator function to be invoked on each element. + /// A function to transform the final accumulator value into the result value. + /// The to monitor for cancellation requests. The default is . + /// The transformed final accumulator value. + /// is . + /// is . + /// is . + public static ValueTask AggregateAsync( + this IAsyncEnumerable source, + TAccumulate seed, + Func func, + Func resultSelector, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(func); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false), seed, func, resultSelector); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source, + TAccumulate seed, + Func func, + Func resultSelector) + { + TAccumulate result = seed; + + await foreach (TSource element in source) + { + result = func(result, element); + } + + return resultSelector(result); + } + } + + /// + /// Applies an accumulator function over a sequence. The specified seed value is + /// used as the initial accumulator value, and the specified function is used to + /// select the result value. + /// + /// The type of the elements of source. + /// The type of the accumulator value. + /// The type of the resulting value. + /// An to aggregate over. + /// The initial accumulator value. + /// An accumulator function to be invoked on each element. + /// A function to transform the final accumulator value into the result value. + /// The to monitor for cancellation requests. The default is . + /// The transformed final accumulator value. + /// is . + /// is . + /// is . + public static ValueTask AggregateAsync( + this IAsyncEnumerable source, + TAccumulate seed, + Func> func, + Func> resultSelector, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(func); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(source, seed, func, resultSelector, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + TAccumulate seed, + Func> func, + Func> resultSelector, + CancellationToken cancellationToken) + { + TAccumulate result = seed; + + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + result = await func(result, element, cancellationToken).ConfigureAwait(false); + } + + return await resultSelector(result, cancellationToken).ConfigureAwait(false); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/AggregateBy.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/AggregateBy.cs new file mode 100644 index 00000000000000..6456462d950637 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/AggregateBy.cs @@ -0,0 +1,311 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +#if NET +using System.Runtime.InteropServices; +#endif +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Applies an accumulator function over a sequence, grouping results by key. + /// The type of the elements of . + /// The type of the key returned by . + /// The type of the accumulator value. + /// An to aggregate over. + /// A function to extract the key for each element. + /// The initial accumulator value. + /// An accumulator function to be invoked on each element. + /// An to compare keys with. + /// An enumerable containing the aggregates corresponding to each key deriving from . + /// + /// This method is comparable to the GroupBy methods where each grouping is being aggregated into a single value + /// as opposed to allocating a collection for each group. + /// + /// is . + /// is . + /// is . + public static IAsyncEnumerable> AggregateBy( + this IAsyncEnumerable source, + Func keySelector, + TAccumulate seed, + Func func, + IEqualityComparer? keyComparer = null) + where TKey : notnull + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + ThrowHelper.ThrowIfNull(func); + + return Impl(source, keySelector, seed, func, keyComparer, default); + + static async IAsyncEnumerable> Impl( + IAsyncEnumerable source, + Func keySelector, + TAccumulate seed, + Func func, + IEqualityComparer? keyComparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator enumerator = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await enumerator.MoveNextAsync().ConfigureAwait(false)) + { + yield break; + } + + Dictionary dict = new(keyComparer); + + do + { + TSource value = enumerator.Current; + TKey key = keySelector(value); + +#if NET + ref TAccumulate? acc = ref CollectionsMarshal.GetValueRefOrAddDefault(dict, key, out bool exists); + acc = func(exists ? acc! : seed, value); +#else + dict[key] = func(dict.TryGetValue(key, out TAccumulate? acc) ? acc : seed, value); +#endif + } + while (await enumerator.MoveNextAsync().ConfigureAwait(false)); + + foreach (KeyValuePair countBy in dict) + { + yield return countBy; + } + } + finally + { + await enumerator.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Applies an accumulator function over a sequence, grouping results by key. + /// The type of the elements of . + /// The type of the key returned by . + /// The type of the accumulator value. + /// An to aggregate over. + /// A function to extract the key for each element. + /// The initial accumulator value. + /// An accumulator function to be invoked on each element. + /// An to compare keys with. + /// An enumerable containing the aggregates corresponding to each key deriving from . + /// + /// This method is comparable to the GroupBy methods where each grouping is being aggregated into a single value + /// as opposed to allocating a collection for each group. + /// + /// is . + /// is . + /// is . + public static IAsyncEnumerable> AggregateBy( + this IAsyncEnumerable source, + Func> keySelector, + TAccumulate seed, + Func> func, + IEqualityComparer? keyComparer = null) + where TKey : notnull + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + ThrowHelper.ThrowIfNull(func); + + return Impl(source, keySelector, seed, func, keyComparer, default); + + static async IAsyncEnumerable> Impl( + IAsyncEnumerable source, + Func> keySelector, + TAccumulate seed, + Func> func, + IEqualityComparer? keyComparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator enumerator = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await enumerator.MoveNextAsync().ConfigureAwait(false)) + { + yield break; + } + + Dictionary dict = new(keyComparer); + + do + { + TSource value = enumerator.Current; + TKey key = await keySelector(value, cancellationToken).ConfigureAwait(false); + + dict[key] = await func(dict.TryGetValue(key, out TAccumulate? acc) ? acc : seed, value, cancellationToken).ConfigureAwait(false); + } + while (await enumerator.MoveNextAsync().ConfigureAwait(false)); + + foreach (KeyValuePair countBy in dict) + { + yield return countBy; + } + } + finally + { + await enumerator.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Applies an accumulator function over a sequence, grouping results by key. + /// The type of the elements of . + /// The type of the key returned by . + /// The type of the accumulator value. + /// An to aggregate over. + /// A function to extract the key for each element. + /// A factory for the initial accumulator value. + /// An accumulator function to be invoked on each element. + /// An to compare keys with. + /// An enumerable containing the aggregates corresponding to each key deriving from . + /// + /// This method is comparable to the GroupBy methods where each grouping is being aggregated into a single value + /// as opposed to allocating a collection for each group. + /// + /// is . + /// is . + /// is . + /// is . + public static IAsyncEnumerable> AggregateBy( + this IAsyncEnumerable source, + Func keySelector, + Func seedSelector, + Func func, + IEqualityComparer? keyComparer = null) where TKey : notnull + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + ThrowHelper.ThrowIfNull(seedSelector); + ThrowHelper.ThrowIfNull(func); + + return Impl(source, keySelector, seedSelector, func, keyComparer, default); + + static async IAsyncEnumerable> Impl( + IAsyncEnumerable source, + Func keySelector, + Func seedSelector, + Func func, + IEqualityComparer? keyComparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator enumerator = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await enumerator.MoveNextAsync().ConfigureAwait(false)) + { + yield break; + } + + Dictionary dict = new(keyComparer); + + do + { + TSource value = enumerator.Current; + TKey key = keySelector(value); + +#if NET + ref TAccumulate? acc = ref CollectionsMarshal.GetValueRefOrAddDefault(dict, key, out bool exists); + acc = func(exists ? acc! : seedSelector(key), value); +#else + dict[key] = func(dict.TryGetValue(key, out TAccumulate? acc) ? acc : seedSelector(key), value); +#endif + } + while (await enumerator.MoveNextAsync().ConfigureAwait(false)); + + foreach (KeyValuePair countBy in dict) + { + yield return countBy; + } + } + finally + { + await enumerator.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Applies an accumulator function over a sequence, grouping results by key. + /// The type of the elements of . + /// The type of the key returned by . + /// The type of the accumulator value. + /// An to aggregate over. + /// A function to extract the key for each element. + /// A factory for the initial accumulator value. + /// An accumulator function to be invoked on each element. + /// An to compare keys with. + /// An enumerable containing the aggregates corresponding to each key deriving from . + /// + /// This method is comparable to the GroupBy methods where each grouping is being aggregated into a single value + /// as opposed to allocating a collection for each group. + /// + /// is . + /// is . + /// is . + /// is . + public static IAsyncEnumerable> AggregateBy( + this IAsyncEnumerable source, + Func> keySelector, + Func> seedSelector, + Func> func, + IEqualityComparer? keyComparer = null) where TKey : notnull + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + ThrowHelper.ThrowIfNull(seedSelector); + ThrowHelper.ThrowIfNull(func); + + return Impl(source, keySelector, seedSelector, func, keyComparer, default); + + static async IAsyncEnumerable> Impl( + IAsyncEnumerable source, + Func> keySelector, + Func> seedSelector, + Func> func, + IEqualityComparer? keyComparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator enumerator = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await enumerator.MoveNextAsync().ConfigureAwait(false)) + { + yield break; + } + + Dictionary dict = new(keyComparer); + + do + { + TSource value = enumerator.Current; + TKey key = await keySelector(value, cancellationToken).ConfigureAwait(false); + + dict[key] = await func( + dict.TryGetValue(key, out TAccumulate? acc) ? acc : await seedSelector(key, cancellationToken).ConfigureAwait(false), + value, + cancellationToken).ConfigureAwait(false); + } + while (await enumerator.MoveNextAsync().ConfigureAwait(false)); + + foreach (KeyValuePair countBy in dict) + { + yield return countBy; + } + } + finally + { + await enumerator.DisposeAsync().ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/AllAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/AllAsync.cs new file mode 100644 index 00000000000000..02c5f82a7e0b92 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/AllAsync.cs @@ -0,0 +1,89 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Determines whether all elements of a sequence satisfy a condition. + /// The type of the elements of source. + /// An that contains the elements to apply the predicate to. + /// A function to test each element for a condition. + /// The to monitor for cancellation requests. The default is . + /// + /// true if every element of the source sequence passes the test in the specified predicate, + /// or if the sequence is empty; otherwise, false. + /// + /// is . + /// is . + public static ValueTask AllAsync( + this IAsyncEnumerable source, + Func predicate, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false), predicate); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source, + Func predicate) + { + await foreach (TSource element in source) + { + if (!predicate(element)) + { + return false; + } + } + + return true; + } + } + + /// Determines whether all elements of a sequence satisfy a condition. + /// The type of the elements of source. + /// An that contains the elements to apply the predicate to. + /// A function to test each element for a condition. + /// The to monitor for cancellation requests. The default is . + /// + /// true if every element of the source sequence passes the test in the specified predicate, + /// or if the sequence is empty; otherwise, false. + /// + /// is . + /// is . + public static ValueTask AllAsync( + this IAsyncEnumerable source, + Func> predicate, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func> predicate, + CancellationToken cancellationToken) + { + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (!await predicate(element, cancellationToken).ConfigureAwait(false)) + { + return false; + } + } + + return true; + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/AnyAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/AnyAsync.cs new file mode 100644 index 00000000000000..863b1ee2bbfa44 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/AnyAsync.cs @@ -0,0 +1,118 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Determines whether a sequence contains any elements. + /// The type of the elements of source. + /// The to check for emptiness. + /// The to monitor for cancellation requests. The default is . + /// true if the source sequence contains any elements; otherwise, false. + /// is . + public static ValueTask AnyAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + CancellationToken cancellationToken) + { + IAsyncEnumerator enumerator = source.GetAsyncEnumerator(cancellationToken); + try + { + return await enumerator.MoveNextAsync().ConfigureAwait(false); + } + finally + { + await enumerator.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Determines whether any element of a sequence satisfies a condition. + /// + /// An whose elements to apply the predicate to. + /// A function to test each element for a condition. + /// The to monitor for cancellation requests. The default is . + /// + /// true if the source sequence is not empty and at least one of its elements passes + /// the test in the specified predicate; otherwise, false. + /// + /// is . + /// is . + public static ValueTask AnyAsync( + this IAsyncEnumerable source, + Func predicate, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false), predicate); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source, + Func predicate) + { + await foreach (TSource element in source) + { + if (predicate(element)) + { + return true; + } + } + + return false; + } + } + + /// Determines whether any element of a sequence satisfies a condition. + /// + /// An whose elements to apply the predicate to. + /// A function to test each element for a condition. + /// The to monitor for cancellation requests. The default is . + /// + /// true if the source sequence is not empty and at least one of its elements passes + /// the test in the specified predicate; otherwise, false. + /// + /// is . + /// is . + public static ValueTask AnyAsync( + this IAsyncEnumerable source, + Func> predicate, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func> predicate, + CancellationToken cancellationToken) + { + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (await predicate(element, cancellationToken).ConfigureAwait(false)) + { + return true; + } + } + + return false; + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Append.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Append.cs new file mode 100644 index 00000000000000..299ef4403e1e61 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Append.cs @@ -0,0 +1,41 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Appends a value to the end of the sequence. + /// The type of the elements of source. + /// A sequence of values. + /// The value to append to source. + /// A new sequence that ends with element. + /// is . + public static IAsyncEnumerable Append( + this IAsyncEnumerable source, + TSource element) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source, element, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + TSource element, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (TSource item in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return item; + } + + yield return element; + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/AsyncEnumerable.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/AsyncEnumerable.cs new file mode 100644 index 00000000000000..57c56af8226c8c --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/AsyncEnumerable.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; + +namespace System.Linq +{ + /// + /// Provides a set of static methods for querying objects that implement . + /// + public static partial class AsyncEnumerable + { + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/AverageAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/AverageAsync.cs new file mode 100644 index 00000000000000..2a5a08ab687c47 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/AverageAsync.cs @@ -0,0 +1,339 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Computes the average of a sequence of values. + /// A sequence of values to calculate the average of. + /// The to monitor for cancellation requests. The default is . + /// The average of the sequence of values. + /// is . + /// The sum of the elements in the sequence is larger than (via the returned task). + /// contains no elements (via the returned task). + public static ValueTask AverageAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + long sum = 0; + long count = 0; + await foreach (int item in source) + { + checked { sum += item; } + count++; + } + + if (count == 0) + { + ThrowHelper.ThrowNoElementsException(); + } + + return (double)sum / count; + } + } + + /// Computes the average of a sequence of values. + /// A sequence of values to calculate the average of. + /// The to monitor for cancellation requests. The default is . + /// The average of the sequence of values. + /// is . + /// The sum of the elements in the sequence is larger than (via the returned task). + /// contains no elements (via the returned task). + public static ValueTask AverageAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + long sum = 0; + long count = 0; + await foreach (long item in source) + { + checked { sum += item; } + count++; + } + + if (count == 0) + { + ThrowHelper.ThrowNoElementsException(); + } + + return (double)sum / count; + } + } + + /// Computes the average of a sequence of values. + /// A sequence of values to calculate the average of. + /// The to monitor for cancellation requests. The default is . + /// The average of the sequence of values. + /// is . + /// contains no elements (via the returned task). + public static ValueTask AverageAsync( + this IAsyncEnumerable source, CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + double sum = 0; + long count = 0; + await foreach (double item in source) + { + sum += item; + count++; + } + + if (count == 0) + { + ThrowHelper.ThrowNoElementsException(); + } + + return (float)(sum / count); + } + } + + /// Computes the average of a sequence of values. + /// A sequence of values to calculate the average of. + /// The to monitor for cancellation requests. The default is . + /// The average of the sequence of values. + /// is . + /// contains no elements (via the returned task). + public static ValueTask AverageAsync( + this IAsyncEnumerable source, CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + double sum = 0; + long count = 0; + await foreach (double item in source) + { + sum += item; + count++; + } + + if (count == 0) + { + ThrowHelper.ThrowNoElementsException(); + } + + return (double)sum / count; + } + } + + /// Computes the average of a sequence of values. + /// A sequence of values to calculate the average of. + /// The to monitor for cancellation requests. The default is . + /// The average of the sequence of values. + /// is . + /// contains no elements (via the returned task). + public static ValueTask AverageAsync( + this IAsyncEnumerable source, CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + decimal sum = 0; + long count = 0; + await foreach (decimal item in source) + { + sum += item; + count++; + } + + if (count == 0) + { + ThrowHelper.ThrowNoElementsException(); + } + + return sum / count; + } + } + + /// Computes the average of a sequence of nullable values. + /// A sequence of nullable values to calculate the average of. + /// The to monitor for cancellation requests. The default is . + /// The average of the sequence of values, or null if the source sequence is empty or contains only values that are null. + /// is . + /// The sum of the elements in the sequence is larger than (via the returned task). + /// contains no elements (via the returned task). + public static ValueTask AverageAsync( + this IAsyncEnumerable source, CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + long sum = 0; + long count = 0; + await foreach (int? item in source) + { + if (item is int value) + { + checked { sum += value; } + count++; + } + } + + return count != 0 ? (double)sum / count : null; + } + } + + /// Computes the average of a sequence of nullable values. + /// A sequence of nullable values to calculate the average of. + /// The to monitor for cancellation requests. The default is . + /// The average of the sequence of values, or null if the source sequence is empty or contains only values that are null. + /// is . + /// The sum of the elements in the sequence is larger than (via the returned task). + /// contains no elements (via the returned task). + public static ValueTask AverageAsync( + this IAsyncEnumerable source, CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + long sum = 0; + long count = 0; + await foreach (long? item in source) + { + if (item is long value) + { + checked { sum += value; } + count++; + } + } + + return count != 0 ? (double)sum / count : null; + } + } + + /// Computes the average of a sequence of nullable values. + /// A sequence of nullable values to calculate the average of. + /// The to monitor for cancellation requests. The default is . + /// The average of the sequence of values, or null if the source sequence is empty or contains only values that are null. + /// is . + /// contains no elements (via the returned task). + public static ValueTask AverageAsync( + this IAsyncEnumerable source, CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + double sum = 0; + long count = 0; + await foreach (float? item in source) + { + if (item is float value) + { + sum += value; + count++; + } + } + + return count != 0 ? (float)(sum / count) : null; + } + } + + /// Computes the average of a sequence of nullable values. + /// A sequence of nullable values to calculate the average of. + /// The to monitor for cancellation requests. The default is . + /// The average of the sequence of values, or null if the source sequence is empty or contains only values that are null. + /// is . + /// contains no elements (via the returned task). + public static ValueTask AverageAsync( + this IAsyncEnumerable source, CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + double sum = 0; + long count = 0; + await foreach (double? item in source) + { + if (item is double value) + { + sum += value; + count++; + } + } + + return count != 0 ? sum / count : null; + } + } + + /// Computes the average of a sequence of nullable values. + /// A sequence of nullable values to calculate the average of. + /// The to monitor for cancellation requests. The default is . + /// The average of the sequence of values, or null if the source sequence is empty or contains only values that are null. + /// is . + /// contains no elements (via the returned task). + public static ValueTask AverageAsync( + this IAsyncEnumerable source, CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + decimal sum = 0; + long count = 0; + await foreach (decimal? item in source) + { + if (item is decimal value) + { + sum += value; + count++; + } + } + + return count != 0 ? sum / count : null; + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Cast.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Cast.cs new file mode 100644 index 00000000000000..37121ea5f3977e --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Cast.cs @@ -0,0 +1,44 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + // TODO https://github.com/dotnet/runtime/issues/111717: + // Consider before shipping .NET 10 whether this can instead use extension everything to support any IAsyncEnumerable source. + // Right now it's limited because you can't cast an IAsyncEnumerable to IAsyncEnumerable. But the method with this + // shape is necessary to support query comprehensions with explicit types, e.g. `from string s in asyncEnumerable`. + + /// + /// Casts the elements of an to the specified type. + /// + /// The type to cast the elements of source to. + /// The that contains the elements to be cast to type . + /// An that contains each element of the source sequence cast to the type. + public static IAsyncEnumerable Cast( // satisfies the C# query-expression pattern + this IAsyncEnumerable source) + { + ThrowHelper.ThrowIfNull(source); + + return source is IAsyncEnumerable result ? + result : + Impl(source, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (object? item in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return (TResult)item!; + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Chunk.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Chunk.cs new file mode 100644 index 00000000000000..5392d953e1d960 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Chunk.cs @@ -0,0 +1,102 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Split the elements of a sequence into chunks of size at most . + /// + /// Every chunk except the last will be of size . + /// The last chunk will contain the remaining elements and may be of a smaller size. + /// + /// The type of the elements of source. + /// An whose elements to chunk. + /// Maximum size of each chunk. + /// + /// An that contains the elements of the input sequence split into chunks of size . + /// + /// is . + /// is less than 1. + public static IAsyncEnumerable Chunk( + this IAsyncEnumerable source, + int size) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNegativeOrZero(size); + + return Chunk(source, size, default); + + async static IAsyncEnumerable Chunk( + IAsyncEnumerable source, + int size, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + // Before allocating anything, make sure there's at least one element. + if (await e.MoveNextAsync().ConfigureAwait(false)) + { + // Now that we know we have at least one item, allocate an initial storage array. This is not + // the array we'll yield. It starts out small in order to avoid significantly overallocating + // when the source has many fewer elements than the chunk size. + int arraySize = Math.Min(size, 4); + int i; + do + { + var array = new TSource[arraySize]; + + // Store the first item. + array[0] = e.Current; + i = 1; + + if (size != array.Length) + { + // This is the first chunk. As we fill the array, grow it as needed. + for (; i < size && await e.MoveNextAsync().ConfigureAwait(false); i++) + { + if (i >= array.Length) + { + arraySize = (int)Math.Min((uint)size, 2 * (uint)array.Length); + Array.Resize(ref array, arraySize); + } + + array[i] = e.Current; + } + } + else + { + // For all but the first chunk, the array will already be correctly sized. + // We can just store into it until either it's full or MoveNext returns false. + TSource[] local = array; // avoid bounds checks by using cached local (`array` is lifted to iterator object as a field) + Debug.Assert(local.Length == size); + for (; (uint)i < (uint)local.Length && await e.MoveNextAsync().ConfigureAwait(false); i++) + { + local[i] = e.Current; + } + } + + if (i != array.Length) + { + Array.Resize(ref array, i); + } + + yield return array; + } + while (i >= size && await e.MoveNextAsync().ConfigureAwait(false)); + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Concat.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Concat.cs new file mode 100644 index 00000000000000..e29d2f45384907 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Concat.cs @@ -0,0 +1,45 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Concatenates two sequences. + /// The type of the elements of the input sequences. + /// The first sequence to concatenate. + /// The sequence to concatenate to the first sequence. + /// An that contains the concatenated elements of the two input sequences. + /// is . + /// is . + public static IAsyncEnumerable Concat( + this IAsyncEnumerable first, IAsyncEnumerable second) + { + ThrowHelper.ThrowIfNull(first); + ThrowHelper.ThrowIfNull(second); + + return Impl(first, second, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable first, + IAsyncEnumerable second, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (TSource item in first.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return item; + } + + await foreach (TSource item in second.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return item; + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ContainsAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ContainsAsync.cs new file mode 100644 index 00000000000000..49d9cee96b2041 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ContainsAsync.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Determines whether a sequence contains a specified element. + /// The type of the elements of source. + /// A sequence in which to locate a value. + /// The value to locate in the sequence. + /// An equality comparer to compare values. + /// The to monitor for cancellation requests. The default is . + /// true if the source sequence contains an element that has the specified value; otherwise, false. + /// is . + public static ValueTask ContainsAsync( + this IAsyncEnumerable source, + TSource value, + IEqualityComparer? comparer = null, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false), value, comparer ?? EqualityComparer.Default); + + async static ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source, + TSource value, + IEqualityComparer comparer) + { + await foreach (TSource element in source) + { + if (comparer.Equals(element, value)) + { + return true; + } + } + + return false; + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/CountAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/CountAsync.cs new file mode 100644 index 00000000000000..47088c6a1fe110 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/CountAsync.cs @@ -0,0 +1,226 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Returns the number of elements in a sequence. + /// The type of the elements of source. + /// A sequence that contains elements to be counted. + /// The to monitor for cancellation requests. The default is . + /// The number of elements in the input sequence. + /// is . + /// The number of elements in source is larger than (via the returned task). + public static ValueTask CountAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + int count = 0; + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + checked { count++; } + } + + return count; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Returns the number of elements in a sequence satisfy a condition. + /// The type of the elements of source. + /// A sequence that contains elements to be tested and counted. + /// A function to test each element for a condition. + /// The to monitor for cancellation requests. The default is . + /// The number of elements in the input sequence that satisfy the condition in the predicate function. + /// is . + /// The number of elements that satisfy the condition is larger than (via the returned task). + public static ValueTask CountAsync( + this IAsyncEnumerable source, + Func predicate, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false), predicate); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source, + Func predicate) + { + int count = 0; + await foreach (TSource element in source) + { + if (predicate(element)) + { + checked { count++; } + } + } + + return count; + } + } + + /// Returns the number of elements in a sequence satisfy a condition. + /// The type of the elements of source. + /// A sequence that contains elements to be tested and counted. + /// A function to test each element for a condition. + /// The to monitor for cancellation requests. The default is . + /// The number of elements in the input sequence that satisfy the condition in the predicate function. + /// is . + /// The number of elements that satisfy the condition is larger than (via the returned task). + public static ValueTask CountAsync( + this IAsyncEnumerable source, + Func> predicate, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func> predicate, + CancellationToken cancellationToken = default) + { + int count = 0; + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (await predicate(element, cancellationToken).ConfigureAwait(false)) + { + checked { count++; } + } + } + + return count; + } + } + + /// Returns the number of elements in a sequence satisfy a condition. + /// The type of the elements of source. + /// A sequence that contains elements to be tested and counted. + /// The to monitor for cancellation requests. The default is . + /// The number of elements in the input sequence that satisfy the condition in the predicate function. + /// is . + public static ValueTask LongCountAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + long count = 0; + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + count++; + } + + return count; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Returns the number of elements in a sequence satisfy a condition. + /// The type of the elements of source. + /// A sequence that contains elements to be tested and counted. + /// A function to test each element for a condition. + /// The to monitor for cancellation requests. The default is . + /// The number of elements in the input sequence that satisfy the condition in the predicate function. + /// is . + public static ValueTask LongCountAsync( + this IAsyncEnumerable source, + Func predicate, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false), predicate); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source, + Func predicate) + { + long count = 0; + await foreach (TSource element in source) + { + if (predicate(element)) + { + count++; + } + } + + return count; + } + } + + /// Returns the number of elements in a sequence satisfy a condition. + /// The type of the elements of source. + /// A sequence that contains elements to be tested and counted. + /// A function to test each element for a condition. + /// The to monitor for cancellation requests. The default is . + /// The number of elements in the input sequence that satisfy the condition in the predicate function. + /// is . + public static ValueTask LongCountAsync( + this IAsyncEnumerable source, + Func> predicate, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func> predicate, + CancellationToken cancellationToken = default) + { + long count = 0; + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (await predicate(element, cancellationToken).ConfigureAwait(false)) + { + count++; + } + } + + return count; + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/CountBy.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/CountBy.cs new file mode 100644 index 00000000000000..ba26521df5ca0a --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/CountBy.cs @@ -0,0 +1,124 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Returns the count of elements in the source sequence grouped by key. + /// The type of elements of . + /// The type of the key returned by . + /// A sequence that contains elements to be counted. + /// A function to extract the key for each element. + /// An to compare keys with. + /// An enumerable containing the frequencies of each key occurrence in . + /// is . + /// is . + public static IAsyncEnumerable> CountBy( + this IAsyncEnumerable source, + Func keySelector, + IEqualityComparer? keyComparer = null) where TKey : notnull + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(source, keySelector, keyComparer, default); + + static async IAsyncEnumerable> Impl( + IAsyncEnumerable source, Func keySelector, IEqualityComparer? keyComparer, [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator enumerator = source.GetAsyncEnumerator(cancellationToken); + try + { + if (await enumerator.MoveNextAsync().ConfigureAwait(false)) + { + Dictionary countsBy = new(keyComparer); + do + { + TSource value = enumerator.Current; + TKey key = keySelector(value); + +#if NET + ref int currentCount = ref CollectionsMarshal.GetValueRefOrAddDefault(countsBy, key, out _); + checked { currentCount++; } +#else + countsBy[key] = countsBy.TryGetValue(key, out int currentCount) ? checked(currentCount + 1) : 1; +#endif + } + while (await enumerator.MoveNextAsync().ConfigureAwait(false)); + + foreach (KeyValuePair countBy in countsBy) + { + yield return countBy; + } + } + } + finally + { + await enumerator.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Returns the count of elements in the source sequence grouped by key. + /// The type of elements of . + /// The type of the key returned by . + /// A sequence that contains elements to be counted. + /// A function to extract the key for each element. + /// An to compare keys with. + /// An enumerable containing the frequencies of each key occurrence in . + /// is . + /// is . + public static IAsyncEnumerable> CountBy( + this IAsyncEnumerable source, + Func> keySelector, + IEqualityComparer? keyComparer = null) where TKey : notnull + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(source, keySelector, keyComparer, default); + + static async IAsyncEnumerable> Impl( + IAsyncEnumerable source, Func> keySelector, IEqualityComparer? keyComparer, [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator enumerator = source.GetAsyncEnumerator(cancellationToken); + try + { + if (await enumerator.MoveNextAsync().ConfigureAwait(false)) + { + Dictionary countsBy = new(keyComparer); + do + { + TSource value = enumerator.Current; + TKey key = await keySelector(value, cancellationToken).ConfigureAwait(false); + +#if NET + ref int currentCount = ref CollectionsMarshal.GetValueRefOrAddDefault(countsBy, key, out _); + checked { currentCount++; } +#else + countsBy[key] = countsBy.TryGetValue(key, out int currentCount) ? checked(currentCount + 1) : 1; +#endif + } + while (await enumerator.MoveNextAsync().ConfigureAwait(false)); + + foreach (KeyValuePair countBy in countsBy) + { + yield return countBy; + } + } + } + finally + { + await enumerator.DisposeAsync().ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/DefaultIfEmpty.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/DefaultIfEmpty.cs new file mode 100644 index 00000000000000..513407afdcabd8 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/DefaultIfEmpty.cs @@ -0,0 +1,68 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Returns the elements of the specified sequence or the type parameter's default if the sequence is empty. + /// The type of the elements of source. + /// The sequence to return a default value for if it is empty. + /// + /// An object that contains the default value for + /// the TSource type if source is empty; otherwise, source. + /// + /// is . + public static IAsyncEnumerable DefaultIfEmpty( + this IAsyncEnumerable source) => + DefaultIfEmpty(source, default); + + /// Returns the elements of the specified sequence or the specified value if the sequence is empty. + /// The type of the elements of source. + /// The sequence to return a default value for if it is empty. + /// The value to return if the sequence is empty. + /// + /// An object that contains the default value for + /// the TSource type if source is empty; otherwise, source. + /// + /// is . + public static IAsyncEnumerable DefaultIfEmpty( + this IAsyncEnumerable source, TSource defaultValue) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source, defaultValue, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + TSource defaultValue, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (await e.MoveNextAsync().ConfigureAwait(false)) + { + do + { + yield return e.Current; + } + while (await e.MoveNextAsync().ConfigureAwait(false)); + } + else + { + yield return defaultValue; + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Distinct.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Distinct.cs new file mode 100644 index 00000000000000..ea3d75e87bed25 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Distinct.cs @@ -0,0 +1,55 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Returns distinct elements from a sequence. + /// + /// The sequence to remove duplicate elements from. + /// An to compare values. + /// An that contains distinct elements from the source sequence. + /// is . + public static IAsyncEnumerable Distinct( + this IAsyncEnumerable source, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (await e.MoveNextAsync().ConfigureAwait(false)) + { + HashSet set = new(comparer); + do + { + TSource element = e.Current; + if (set.Add(element)) + { + yield return element; + } + } + while (await e.MoveNextAsync().ConfigureAwait(false)); + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/DistinctBy.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/DistinctBy.cs new file mode 100644 index 00000000000000..8c71a327abc514 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/DistinctBy.cs @@ -0,0 +1,121 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Returns distinct elements from a sequence according to a specified key selector function. + /// The type of the elements of . + /// The type of key to distinguish elements by. + /// The sequence to remove duplicate elements from. + /// A function to extract the key for each element. + /// An to compare keys. + /// An that contains distinct elements from the source sequence. + /// is . + /// + /// This method is implemented by using deferred execution. The immediate return value is an object that stores all the information that is required to perform the action. The query represented by this method is not executed until the object is enumerated either by calling its `GetEnumerator` method directly or by using `foreach` in Visual C# or `For Each` in Visual Basic. + /// The method returns an unordered sequence that contains no duplicate values. If is , the default equality comparer, , is used to compare values. + /// + /// is . + /// is . + public static IAsyncEnumerable DistinctBy( + this IAsyncEnumerable source, + Func keySelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(source, keySelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func keySelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator enumerator = source.GetAsyncEnumerator(cancellationToken); + try + { + if (await enumerator.MoveNextAsync().ConfigureAwait(false)) + { + HashSet set = new(comparer); + do + { + TSource element = enumerator.Current; + if (set.Add(keySelector(element))) + { + yield return element; + } + } + while (await enumerator.MoveNextAsync().ConfigureAwait(false)); + } + } + finally + { + await enumerator.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Returns distinct elements from a sequence according to a specified key selector function. + /// The type of the elements of . + /// The type of key to distinguish elements by. + /// The sequence to remove duplicate elements from. + /// A function to extract the key for each element. + /// An to compare keys. + /// An that contains distinct elements from the source sequence. + /// is . + /// + /// This method is implemented by using deferred execution. The immediate return value is an object that stores all the information that is required to perform the action. The query represented by this method is not executed until the object is enumerated either by calling its `GetEnumerator` method directly or by using `foreach` in Visual C# or `For Each` in Visual Basic. + /// The method returns an unordered sequence that contains no duplicate values. If is , the default equality comparer, , is used to compare values. + /// + /// is . + /// is . + public static IAsyncEnumerable DistinctBy( + this IAsyncEnumerable source, + Func> keySelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(source, keySelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> keySelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator enumerator = source.GetAsyncEnumerator(cancellationToken); + try + { + if (await enumerator.MoveNextAsync().ConfigureAwait(false)) + { + HashSet set = new(comparer); + do + { + TSource element = enumerator.Current; + if (set.Add(await keySelector(element, cancellationToken).ConfigureAwait(false))) + { + yield return element; + } + } + while (await enumerator.MoveNextAsync().ConfigureAwait(false)); + } + } + finally + { + await enumerator.DisposeAsync().ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ElementAtAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ElementAtAsync.cs new file mode 100644 index 00000000000000..d0f54bfc97efbf --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ElementAtAsync.cs @@ -0,0 +1,187 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Returns the element at a specified index in a sequence. + /// The type of the elements of source. + /// An to return an element from. + /// The index of the element to retrieve, which is either from the beginning or the end of the sequence. + /// The to monitor for cancellation requests. The default is . + /// The element at the specified position in the source sequence. + /// is . + /// is outside the bounds of the source sequence (via the returned task). + public static ValueTask ElementAtAsync( + this IAsyncEnumerable source, + int index, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return ElementAtOrDefaultAsync(source, index, throwIfNotFound: true, cancellationToken)!; + } + + /// Returns the element at a specified index in a sequence, or a default value if the index is out of range. + /// The type of the elements of source. + /// An to return an element from. + /// The index of the element to retrieve, which is either from the beginning or the end of the sequence. + /// The to monitor for cancellation requests. The default is . + /// + /// The default value of if is outside the bounds of the source sequence; otherwise, the + /// element at the specified position in the source sequence. + /// + /// is . + public static ValueTask ElementAtOrDefaultAsync( + this IAsyncEnumerable source, + int index, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return ElementAtOrDefaultAsync(source, index, throwIfNotFound: false, cancellationToken); + } + + /// Returns the element at a specified index in a sequence. + /// The type of the elements of . + /// An to return an element from. + /// The index of the element to retrieve, which is either from the start or the end. + /// The token to monitor for cancellation requests. The default value is . + /// is . + /// is outside the bounds of the sequence. + /// The element at the specified position in the sequence. + /// + /// If the type of implements , that implementation is used to obtain the element at the specified index. Otherwise, this method obtains the specified element. + /// This method throws an exception if is out of range. To instead return a default value when the specified index is out of range, use the ElementAtOrDefaultAsync method. + /// + /// is . + /// is outside the bounds of the source sequence (via the returned task). + public static ValueTask ElementAtAsync( + this IAsyncEnumerable source, + Index index, + CancellationToken cancellationToken = default) + { + if (!index.IsFromEnd) + { + return ElementAtAsync(source, index.Value, cancellationToken); + } + + ThrowHelper.ThrowIfNull(source); + + return ElementAtFromEndOrDefault(source, index.Value, throwIfNotFound: true, cancellationToken)!; + } + + /// Returns the element at a specified index in a sequence or a default value if the index is out of range. + /// The type of the elements of . + /// An to return an element from. + /// The index of the element to retrieve, which is either from the start or the end. + /// The token to monitor for cancellation requests. The default value is . + /// is . + /// if is outside the bounds of the sequence; otherwise, the element at the specified position in the sequence. + /// + /// If the type of implements , that implementation is used to obtain the element at the specified index. Otherwise, this method obtains the specified element. + /// The default value for reference and nullable types is . + /// + /// is . + public static ValueTask ElementAtOrDefaultAsync( + this IAsyncEnumerable source, + Index index, + CancellationToken cancellationToken = default) + { + if (!index.IsFromEnd) + { + return ElementAtOrDefaultAsync(source, index.Value, cancellationToken); + } + + ThrowHelper.ThrowIfNull(source); + + return ElementAtFromEndOrDefault(source, index.Value, throwIfNotFound: false, cancellationToken); + } + + private static async ValueTask ElementAtOrDefaultAsync( + IAsyncEnumerable source, + int index, + bool throwIfNotFound, + CancellationToken cancellationToken = default) + { + if (index >= 0) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + if (index == 0) + { + return e.Current; + } + + index--; + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + + if (throwIfNotFound) + { + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(index)); + } + + return default; + } + + private static async ValueTask ElementAtFromEndOrDefault( + IAsyncEnumerable source, + int indexFromEnd, + bool throwIfNotFound, + CancellationToken cancellationToken) + { + if (indexFromEnd > 0) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (await e.MoveNextAsync().ConfigureAwait(false)) + { + Queue queue = new(); + queue.Enqueue(e.Current); + + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + if (queue.Count == indexFromEnd) + { + queue.Dequeue(); + } + + queue.Enqueue(e.Current); + } + + if (queue.Count == indexFromEnd) + { + return queue.Dequeue(); + } + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + + if (throwIfNotFound) + { + ThrowHelper.ThrowArgumentOutOfRangeException("index"); + } + + return default; + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Empty.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Empty.cs new file mode 100644 index 00000000000000..fc1fb0dc7e9d78 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Empty.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// + /// Returns an empty that has the specified type argument. + /// + /// The type of the elements of the sequence. + /// An empty whose type argument is . + public static IAsyncEnumerable Empty() => EmptyAsyncEnumerable.Instance; + + private sealed class EmptyAsyncEnumerable : IAsyncEnumerable, IAsyncEnumerator + { + public static EmptyAsyncEnumerable Instance { get; } = new EmptyAsyncEnumerable(); + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) => this; + + public ValueTask MoveNextAsync() => default; + + public TResult Current => default!; + + public ValueTask DisposeAsync() => default; + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Except.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Except.cs new file mode 100644 index 00000000000000..a1946b1c0b04fa --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Except.cs @@ -0,0 +1,54 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Produces the set difference of two sequences. + /// The type of the elements of the input sequences. + /// An whose elements that are not also in second will be returned. + /// An whose elements that also occur in the first sequence will cause those elements to be removed from the returned sequence. + /// An to compare values. + /// A sequence that contains the set difference of the elements of two sequences. + /// is . + /// is . + public static IAsyncEnumerable Except( + this IAsyncEnumerable first, + IAsyncEnumerable second, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(first); + ThrowHelper.ThrowIfNull(second); + + return Impl(first, second, comparer, default); + + async static IAsyncEnumerable Impl( + IAsyncEnumerable first, + IAsyncEnumerable second, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + HashSet set = new(comparer); + + await foreach (TSource element in second.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + set.Add(element); + } + + await foreach (TSource element in first.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (set.Add(element)) + { + yield return element; + } + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ExceptBy.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ExceptBy.cs new file mode 100644 index 00000000000000..cabc34ceeaaa7a --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ExceptBy.cs @@ -0,0 +1,111 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// + /// Produces the set difference of two sequences according to a specified key selector function. + /// + /// The type of the elements of the input sequence. + /// The type of key to identify elements by. + /// An whose keys that are not also in will be returned. + /// An whose keys that also occur in the first sequence will cause those elements to be removed from the returned sequence. + /// A function to extract the key for each element. + /// The to compare values. + /// A sequence that contains the set difference of the elements of two sequences. + /// is . + /// is . + /// is . + public static IAsyncEnumerable ExceptBy( + this IAsyncEnumerable first, + IAsyncEnumerable second, + Func keySelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(first); + ThrowHelper.ThrowIfNull(second); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(first, second, keySelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable first, + IAsyncEnumerable second, + Func keySelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + HashSet set = new(comparer); + + await foreach (TKey key in second.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + set.Add(key); + } + + await foreach (TSource element in first.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (set.Add(keySelector(element))) + { + yield return element; + } + } + } + } + + /// + /// Produces the set difference of two sequences according to a specified key selector function. + /// + /// The type of the elements of the input sequence. + /// The type of key to identify elements by. + /// An whose keys that are not also in will be returned. + /// An whose keys that also occur in the first sequence will cause those elements to be removed from the returned sequence. + /// A function to extract the key for each element. + /// The to compare values. + /// A sequence that contains the set difference of the elements of two sequences. + /// is . + /// is . + /// is . + public static IAsyncEnumerable ExceptBy( + this IAsyncEnumerable first, + IAsyncEnumerable second, + Func> keySelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(first); + ThrowHelper.ThrowIfNull(second); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(first, second, keySelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable first, + IAsyncEnumerable second, + Func> keySelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + HashSet set = new(comparer); + + await foreach (TKey key in second.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + set.Add(key); + } + + await foreach (TSource element in first.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (set.Add(await keySelector(element, cancellationToken).ConfigureAwait(false))) + { + yield return element; + } + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/FirstAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/FirstAsync.cs new file mode 100644 index 00000000000000..4b3c432803c7dc --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/FirstAsync.cs @@ -0,0 +1,281 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Returns the first element of a sequence. + /// The type of the elements of source. + /// The to return the first element of. + /// The to monitor for cancellation requests. The default is . + /// The first element in the specified sequence. + /// is . + /// The source sequence is empty (via the returned task). + public static ValueTask FirstAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + ThrowHelper.ThrowNoElementsException(); + } + + return e.Current; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Returns the first element in a sequence that satisfies a specified condition. + /// The type of the elements of source. + /// An to return an element from. + /// A function to test each element for a condition. + /// The to monitor for cancellation requests. The default is . + /// The first element in the sequence that passes the test in the specified predicate function. + /// is . + /// is . + /// + /// The source sequence is empty, or no element in the sequence satisfies + /// the condition in predicate (via the returned task). + /// + public static ValueTask FirstAsync( + this IAsyncEnumerable source, + Func predicate, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false), predicate); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source, + Func predicate) + { + await foreach (TSource item in source) + { + if (predicate(item)) + { + return item; + } + } + + ThrowHelper.ThrowNoElementsException(); + return default!; // unreachable + } + } + + /// Returns the first element in a sequence that satisfies a specified condition. + /// The type of the elements of source. + /// An to return an element from. + /// A function to test each element for a condition. + /// The to monitor for cancellation requests. The default is . + /// The first element in the sequence that passes the test in the specified predicate function. + /// is . + /// is . + /// + /// The source sequence is empty, or no element in the sequence satisfies + /// the condition in predicate (via the returned task). + /// + public static ValueTask FirstAsync( + this IAsyncEnumerable source, + Func> predicate, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func> predicate, + CancellationToken cancellationToken) + { + await foreach (TSource item in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (await predicate(item, cancellationToken).ConfigureAwait(false)) + { + return item; + } + } + + ThrowHelper.ThrowNoElementsException(); + return default!; // unreachable + } + } + + /// Returns the first element of a sequence, or the default value of if the sequence contains no elements. + /// The type of the elements of . + /// The to return the first element of. + /// The to monitor for cancellation requests. The default is . + /// The default value of if is empty; otherwise, the first element in . + /// is . + public static ValueTask FirstOrDefaultAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) => + FirstOrDefaultAsync(source, default(TSource), cancellationToken)!; + + /// Returns the first element of a sequence, or a default value if the sequence contains no elements. + /// The type of the elements of . + /// The to return the first element of. + /// The default value to return if the sequence is empty. + /// The to monitor for cancellation requests. The default is . + /// if is empty; otherwise, the first element in . + /// is . + public static ValueTask FirstOrDefaultAsync( + this IAsyncEnumerable source, + TSource defaultValue, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source, defaultValue, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + TSource defaultValue, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + return await e.MoveNextAsync().ConfigureAwait(false) ? e.Current : defaultValue; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Returns the first element of the sequence that satisfies a condition or a default value if no such element is found. + /// + /// An to return an element from. + /// A function to test each element for a condition. + /// The to monitor for cancellation requests. The default is . + /// + /// The default value of if source is empty or if no element passes the test specified + /// by predicate; otherwise, the first element in source that passes the test specified by predicate. + /// + /// is . + /// is . + public static ValueTask FirstOrDefaultAsync( + this IAsyncEnumerable source, + Func predicate, + CancellationToken cancellationToken = default) => + FirstOrDefaultAsync(source, predicate!, default, cancellationToken); + + /// Returns the first element of the sequence that satisfies a condition or a default value if no such element is found. + /// + /// An to return an element from. + /// A function to test each element for a condition. + /// The to monitor for cancellation requests. The default is . + /// + /// The default value of if source is empty or if no element passes the test specified + /// by predicate; otherwise, the first element in source that passes the test specified by predicate. + /// + /// is . + /// is . + public static ValueTask FirstOrDefaultAsync( + this IAsyncEnumerable source, + Func> predicate, + CancellationToken cancellationToken = default) => + FirstOrDefaultAsync(source, predicate!, default, cancellationToken); + + /// Returns the first element of the sequence that satisfies a condition or a default value if no such element is found. + /// The type of the elements of . + /// An to return an element from. + /// A function to test each element for a condition. + /// The default value to return if the sequence is empty. + /// The to monitor for cancellation requests. The default is . + /// if is empty or if no element passes the test specified by ; otherwise, the first element in that passes the test specified by . + /// is . + /// is . + public static ValueTask FirstOrDefaultAsync( + this IAsyncEnumerable source, + Func predicate, + TSource defaultValue, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false), predicate, defaultValue); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source, + Func predicate, + TSource defaultValue) + { + await foreach (TSource item in source) + { + if (predicate(item)) + { + return item; + } + } + + return defaultValue; + } + } + + /// Returns the first element of the sequence that satisfies a condition or a default value if no such element is found. + /// The type of the elements of . + /// An to return an element from. + /// A function to test each element for a condition. + /// The default value to return if the sequence is empty. + /// The to monitor for cancellation requests. The default is . + /// if is empty or if no element passes the test specified by ; otherwise, the first element in that passes the test specified by . + /// is . + /// is . + public static ValueTask FirstOrDefaultAsync( + this IAsyncEnumerable source, + Func> predicate, + TSource defaultValue, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, defaultValue, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func> predicate, + TSource defaultValue, + CancellationToken cancellationToken) + { + await foreach (TSource item in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (await predicate(item, cancellationToken).ConfigureAwait(false)) + { + return item; + } + } + + return defaultValue; + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/GroupBy.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/GroupBy.cs new file mode 100644 index 00000000000000..8eac9561756b5e --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/GroupBy.cs @@ -0,0 +1,449 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Groups the elements of a sequence according to a specified key selector function. + /// The type of the elements of source. + /// The type of the key returned by . + /// An of elements to group. + /// A function to extract the key for each element. + /// An to compare keys. + /// + /// An where each + /// contains a sequence of objects and a key. + /// + /// is . + /// is . + public static IAsyncEnumerable> GroupBy( // satisfies the C# query-expression pattern + this IAsyncEnumerable source, + Func keySelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(source, keySelector, comparer, default); + + static async IAsyncEnumerable> Impl( + IAsyncEnumerable source, + Func keySelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + foreach (IGrouping item in await ToLookupAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false)) + { + yield return item; + } + } + } + + /// Groups the elements of a sequence according to a specified key selector function. + /// The type of the elements of source. + /// The type of the key returned by . + /// An of elements to group. + /// A function to extract the key for each element. + /// An to compare keys. + /// + /// An where each + /// contains a sequence of objects and a key. + /// + /// is . + /// is . + public static IAsyncEnumerable> GroupBy( + this IAsyncEnumerable source, + Func> keySelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(source, keySelector, comparer, default); + + static async IAsyncEnumerable> Impl( + IAsyncEnumerable source, + Func> keySelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + foreach (IGrouping item in await ToLookupAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false)) + { + yield return item; + } + } + } + + /// + /// Groups the elements of a sequence according to a key selector function. The keys + /// are compared by using a comparer and each group's elements are projected by using + /// a specified function. + /// + /// The type of the elements of source. + /// The type of the key returned by . + /// The type of the elements in the . + /// An of elements to group. + /// A function to extract the key for each element. + /// A function to map each source element to an element in an . + /// An to compare keys. + /// + /// An where each + /// contains a sequence of objects of type and a key. + /// + /// is . + /// is . + /// is . + public static IAsyncEnumerable> GroupBy( // satisfies the C# query-expression pattern + this IAsyncEnumerable source, + Func keySelector, + Func elementSelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + ThrowHelper.ThrowIfNull(elementSelector); + + return Impl(source, keySelector, elementSelector, comparer, default); + + static async IAsyncEnumerable> Impl( + IAsyncEnumerable source, + Func keySelector, + Func elementSelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + foreach (IGrouping item in await ToLookupAsync(source, keySelector, elementSelector, comparer, cancellationToken).ConfigureAwait(false)) + { + yield return item; + } + } + } + + /// + /// Groups the elements of a sequence according to a key selector function. The keys + /// are compared by using a comparer and each group's elements are projected by using + /// a specified function. + /// + /// The type of the elements of source. + /// The type of the key returned by . + /// The type of the elements in the . + /// An of elements to group. + /// A function to extract the key for each element. + /// A function to map each source element to an element in an . + /// An to compare keys. + /// + /// An where each + /// contains a sequence of objects of type and a key. + /// + /// is . + /// is . + /// is . + public static IAsyncEnumerable> GroupBy( + this IAsyncEnumerable source, + Func> keySelector, + Func> elementSelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + ThrowHelper.ThrowIfNull(elementSelector); + + return Impl(source, keySelector, elementSelector, comparer, default); + + static async IAsyncEnumerable> Impl( + IAsyncEnumerable source, + Func> keySelector, + Func> elementSelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + foreach (IGrouping item in await ToLookupAsync(source, keySelector, elementSelector, comparer, cancellationToken).ConfigureAwait(false)) + { + yield return item; + } + } + } + + /// + /// Groups the elements of a sequence according to a specified key selector function + /// and creates a result value from each group and its key. + /// + /// The type of the elements of source. + /// The type of the key returned by . + /// The type of the result value returned by resultSelector. + /// An of elements to group. + /// A function to extract the key for each element. + /// A function to create a result value from each group. + /// An to compare keys. + /// + /// A collection of elements of type where each element represents + /// a projection over a group and its key. + /// + /// is . + /// is . + /// is . + public static IAsyncEnumerable GroupBy( + this IAsyncEnumerable source, + Func keySelector, + Func, TResult> resultSelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(source, keySelector, resultSelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func keySelector, + Func, TResult> resultSelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + var lookup = (AsyncLookup)await ToLookupAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false); + foreach (TResult item in lookup.ApplyResultSelector(resultSelector)) + { + yield return item; + } + } + } + + /// + /// Groups the elements of a sequence according to a specified key selector function + /// and creates a result value from each group and its key. + /// + /// The type of the elements of source. + /// The type of the key returned by . + /// The type of the result value returned by resultSelector. + /// An of elements to group. + /// A function to extract the key for each element. + /// A function to create a result value from each group. + /// An to compare keys. + /// + /// A collection of elements of type where each element represents + /// a projection over a group and its key. + /// + /// is . + /// is . + /// is . + public static IAsyncEnumerable GroupBy( + this IAsyncEnumerable source, + Func> keySelector, + Func, CancellationToken, ValueTask> resultSelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(source, keySelector, resultSelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> keySelector, + Func, CancellationToken, ValueTask> resultSelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + var lookup = (AsyncLookup)await ToLookupAsync(source, keySelector, comparer, cancellationToken).ConfigureAwait(false); + await foreach (TResult item in lookup.ApplyResultSelector(resultSelector, cancellationToken).ConfigureAwait(false)) + { + yield return item; + } + } + } + + /// + /// Groups the elements of a sequence according to a specified key selector function + /// and creates a result value from each group and its key. Key values are compared + /// by using a specified comparer, and the elements of each group are projected by + /// using a specified function. + /// + /// The type of the elements of source. + /// The type of the key returned by . + /// The type of the elements in each . + /// The type of the result value returned by . + /// An of elements to group. + /// A function to extract the key for each element. + /// A function to map each source element to an element in an . + /// A function to create a result value from each group. + /// An to compare keys. + /// A collection of elements of type where each element represents a projection over a group and its key. + /// is . + /// is . + /// is . + /// is . + public static IAsyncEnumerable GroupBy( + this IAsyncEnumerable source, + Func keySelector, + Func elementSelector, + Func, TResult> resultSelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + ThrowHelper.ThrowIfNull(elementSelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(source, keySelector, elementSelector, resultSelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func keySelector, + Func elementSelector, + Func, TResult> resultSelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + var lookup = (AsyncLookup)await ToLookupAsync(source, keySelector, elementSelector, comparer, cancellationToken).ConfigureAwait(false); + foreach (TResult item in lookup.ApplyResultSelector(resultSelector)) + { + yield return item; + } + } + } + + /// + /// Groups the elements of a sequence according to a specified key selector function + /// and creates a result value from each group and its key. Key values are compared + /// by using a specified comparer, and the elements of each group are projected by + /// using a specified function. + /// + /// The type of the elements of source. + /// The type of the key returned by . + /// The type of the elements in each . + /// The type of the result value returned by . + /// An of elements to group. + /// A function to extract the key for each element. + /// A function to map each source element to an element in an . + /// A function to create a result value from each group. + /// An to compare keys. + /// A collection of elements of type where each element represents a projection over a group and its key. + /// is . + /// is . + /// is . + /// is . + public static IAsyncEnumerable GroupBy( + this IAsyncEnumerable source, + Func> keySelector, + Func> elementSelector, + Func, CancellationToken, ValueTask> resultSelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + ThrowHelper.ThrowIfNull(elementSelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(source, keySelector, elementSelector, resultSelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> keySelector, + Func> elementSelector, + Func, CancellationToken, ValueTask> resultSelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + var lookup = (AsyncLookup)await ToLookupAsync(source, keySelector, elementSelector, comparer, cancellationToken).ConfigureAwait(false); + await foreach (TResult item in lookup.ApplyResultSelector(resultSelector, cancellationToken).ConfigureAwait(false)) + { + yield return item; + } + } + } + + internal sealed class Grouping : IGrouping, IList + { + internal readonly TKey _key; + internal readonly int _hashCode; + internal TElement[] _elements; + internal int _count; + internal Grouping? _hashNext; + internal Grouping? _next; + + internal Grouping(TKey key, int hashCode) + { + _key = key; + _hashCode = hashCode; + _elements = new TElement[1]; + } + + internal void Add(TElement element) + { + if (_elements.Length == _count) + { + Array.Resize(ref _elements, checked(_count * 2)); + } + + _elements[_count] = element; + _count++; + } + + internal void Trim() + { + if (_elements.Length != _count) + { + Array.Resize(ref _elements, _count); + } + } + + public IEnumerator GetEnumerator() + { + Debug.Assert(_count > 0, "A grouping should only have been created if an element was being added to it."); + for (int i = 0; i < _count; i++) + { + yield return _elements[i]; + } + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + public TKey Key => _key; + + int ICollection.Count => _count; + + bool ICollection.IsReadOnly => true; + + void ICollection.Add(TElement item) => throw new NotSupportedException(); + + void ICollection.Clear() => throw new NotSupportedException(); + + bool ICollection.Contains(TElement item) => Array.IndexOf(_elements, item, 0, _count) >= 0; + + void ICollection.CopyTo(TElement[] array, int arrayIndex) => + Array.Copy(_elements, 0, array, arrayIndex, _count); + + bool ICollection.Remove(TElement item) => throw new NotSupportedException(); + + int IList.IndexOf(TElement item) => Array.IndexOf(_elements, item, 0, _count); + + void IList.Insert(int index, TElement item) => throw new NotSupportedException(); + + void IList.RemoveAt(int index) => throw new NotSupportedException(); + + TElement IList.this[int index] + { + get + { + if ((uint)index >= (uint)_count) + { + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(index)); + } + + return _elements[index]; + } + set => throw new NotSupportedException(); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/GroupJoin.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/GroupJoin.cs new file mode 100644 index 00000000000000..3b2d56147bb31d --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/GroupJoin.cs @@ -0,0 +1,154 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Correlates the elements of two sequences based on key equality and groups the results. + /// + /// + /// + /// + /// The first sequence to join. + /// The sequence to join to the first sequence. + /// A function to extract the join key from each element of the first sequence. + /// A function to extract the join key from each element of the second sequence. + /// + /// A function to create a result element from an element from the first sequence + /// and a collection of matching elements from the second sequence. + /// + /// An to use to hash and compare keys. + /// + /// An that contains elements of type + /// that are obtained by performing a grouped join on two sequences. + /// + /// is . + /// is . + /// is . + /// is . + /// is . + public static IAsyncEnumerable GroupJoin( // satisfies the C# query-expression pattern + this IAsyncEnumerable outer, + IAsyncEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func, TResult> resultSelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(outer); + ThrowHelper.ThrowIfNull(inner); + ThrowHelper.ThrowIfNull(outerKeySelector); + ThrowHelper.ThrowIfNull(innerKeySelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable outer, + IAsyncEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func, TResult> resultSelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e = outer.GetAsyncEnumerator(cancellationToken); + try + { + if (await e.MoveNextAsync().ConfigureAwait(false)) + { + AsyncLookup lookup = await AsyncLookup.CreateForJoinAsync(inner, innerKeySelector, comparer, cancellationToken).ConfigureAwait(false); + do + { + TOuter item = e.Current; + yield return resultSelector(item, lookup[outerKeySelector(item)]); + } + while (await e.MoveNextAsync().ConfigureAwait(false)); + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Correlates the elements of two sequences based on key equality and groups the results. + /// + /// + /// + /// + /// The first sequence to join. + /// The sequence to join to the first sequence. + /// A function to extract the join key from each element of the first sequence. + /// A function to extract the join key from each element of the second sequence. + /// + /// A function to create a result element from an element from the first sequence + /// and a collection of matching elements from the second sequence. + /// + /// An to use to hash and compare keys. + /// + /// An that contains elements of type + /// that are obtained by performing a grouped join on two sequences. + /// + /// is . + /// is . + /// is . + /// is . + /// is . + public static IAsyncEnumerable GroupJoin( + this IAsyncEnumerable outer, + IAsyncEnumerable inner, + Func> outerKeySelector, + Func> innerKeySelector, + Func, CancellationToken, ValueTask> resultSelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(outer); + ThrowHelper.ThrowIfNull(inner); + ThrowHelper.ThrowIfNull(outerKeySelector); + ThrowHelper.ThrowIfNull(innerKeySelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable outer, + IAsyncEnumerable inner, + Func> outerKeySelector, + Func> innerKeySelector, + Func, CancellationToken, ValueTask> resultSelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e = outer.GetAsyncEnumerator(cancellationToken); + try + { + if (await e.MoveNextAsync().ConfigureAwait(false)) + { + AsyncLookup lookup = await AsyncLookup.CreateForJoinAsync(inner, innerKeySelector, comparer, cancellationToken).ConfigureAwait(false); + do + { + TOuter item = e.Current; + yield return await resultSelector( + item, + lookup[await outerKeySelector(item, cancellationToken).ConfigureAwait(false)], + cancellationToken).ConfigureAwait(false); + } + while (await e.MoveNextAsync().ConfigureAwait(false)); + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Index.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Index.cs new file mode 100644 index 00000000000000..758ca28b6b5d72 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Index.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Returns an enumerable that incorporates the element's index into a tuple. + /// The type of the elements of . + /// The source enumerable providing the elements. + /// An enumerable that incorporates each element index into a tuple. + /// is . + public static IAsyncEnumerable<(int Index, TSource Item)> Index( + this IAsyncEnumerable source) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source, default); + + static async IAsyncEnumerable<(int Index, TSource Item)> Impl( + IAsyncEnumerable source, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + int index = -1; + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return (checked(++index), element); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Intersect.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Intersect.cs new file mode 100644 index 00000000000000..97382080923145 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Intersect.cs @@ -0,0 +1,68 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Produces the set intersection of two sequences. + /// The type of the elements of the input sequences. + /// An whose distinct elements that also appear in second will be returned. + /// An whose distinct elements that also appear in the first sequence will be returned. + /// An to compare values. + /// A sequence that contains the elements that form the set intersection of two sequences. + /// is . + /// is . + public static IAsyncEnumerable Intersect( + this IAsyncEnumerable first, + IAsyncEnumerable second, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(first); + ThrowHelper.ThrowIfNull(second); + + return Impl(first, second, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable first, + IAsyncEnumerable second, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + HashSet set; + IAsyncEnumerator e = second.GetAsyncEnumerator(cancellationToken); + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + yield break; + } + + set = new(comparer); + do + { + set.Add(e.Current); + } + while (await e.MoveNextAsync().ConfigureAwait(false)); + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + + await foreach (TSource element in first.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (set.Remove(element)) + { + yield return element; + } + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/IntersectBy.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/IntersectBy.cs new file mode 100644 index 00000000000000..2960209f2f43c5 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/IntersectBy.cs @@ -0,0 +1,149 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Produces the set intersection of two sequences according to a specified key selector function. + /// The type of the elements of the input sequences. + /// The type of key to identify elements by. + /// An whose distinct elements that also appear in will be returned. + /// An whose distinct elements that also appear in the first sequence will be returned. + /// A function to extract the key for each element. + /// An to compare keys. + /// A sequence that contains the elements that form the set intersection of two sequences. + /// or is . + /// + /// This method is implemented by using deferred execution. The immediate return value is an object that stores all the information that is required to perform the action. The query represented by this method is not executed until the object is enumerated either by calling its `GetEnumerator` method directly or by using `foreach` in Visual C# or `For Each` in Visual Basic. + /// The intersection of two sets A and B is defined as the set that contains all the elements of A that also appear in B, but no other elements. + /// When the object returned by this method is enumerated, `Intersect` yields distinct elements occurring in both sequences in the order in which they appear in . + /// If is , the default equality comparer, , is used to compare values. + /// + /// is . + /// is . + /// is . + public static IAsyncEnumerable IntersectBy( + this IAsyncEnumerable first, + IAsyncEnumerable second, + Func keySelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(first); + ThrowHelper.ThrowIfNull(second); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(first, second, keySelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable first, + IAsyncEnumerable second, + Func keySelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + HashSet set; + IAsyncEnumerator e = second.GetAsyncEnumerator(cancellationToken); + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + yield break; + } + + set = new(comparer); + do + { + set.Add(e.Current); + } + while (await e.MoveNextAsync().ConfigureAwait(false)); + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + + await foreach (TSource element in first.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (set.Remove(keySelector(element))) + { + yield return element; + } + } + } + } + + /// Produces the set intersection of two sequences according to a specified key selector function. + /// The type of the elements of the input sequences. + /// The type of key to identify elements by. + /// An whose distinct elements that also appear in will be returned. + /// An whose distinct elements that also appear in the first sequence will be returned. + /// A function to extract the key for each element. + /// An to compare keys. + /// A sequence that contains the elements that form the set intersection of two sequences. + /// or is . + /// + /// This method is implemented by using deferred execution. The immediate return value is an object that stores all the information that is required to perform the action. The query represented by this method is not executed until the object is enumerated either by calling its `GetEnumerator` method directly or by using `foreach` in Visual C# or `For Each` in Visual Basic. + /// The intersection of two sets A and B is defined as the set that contains all the elements of A that also appear in B, but no other elements. + /// When the object returned by this method is enumerated, `Intersect` yields distinct elements occurring in both sequences in the order in which they appear in . + /// If is , the default equality comparer, , is used to compare values. + /// + /// is . + /// is . + /// is . + public static IAsyncEnumerable IntersectBy( + this IAsyncEnumerable first, + IAsyncEnumerable second, + Func> keySelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(first); + ThrowHelper.ThrowIfNull(second); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(first, second, keySelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable first, + IAsyncEnumerable second, + Func> keySelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + HashSet set; + IAsyncEnumerator e = second.GetAsyncEnumerator(cancellationToken); + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + yield break; + } + + set = new(comparer); + do + { + set.Add(e.Current); + } + while (await e.MoveNextAsync().ConfigureAwait(false)); + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + + await foreach (TSource element in first.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (set.Remove(await keySelector(element, cancellationToken).ConfigureAwait(false))) + { + yield return element; + } + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Join.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Join.cs new file mode 100644 index 00000000000000..d6acc4a3012d35 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Join.cs @@ -0,0 +1,168 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Correlates the elements of two sequences based on matching keys. + /// The type of the elements of the first sequence. + /// The type of the elements of the second sequence. + /// The type of the keys returned by the key selector functions. + /// The type of the result elements. + /// The first sequence to join. + /// The sequence to join to the first sequence. + /// A function to extract the join key from each element of the first sequence. + /// A function to extract the join key from each element of the second sequence. + /// A function to create a result element from two matching elements. + /// An to use to hash and compare keys. + /// + /// An that has elements of type + /// that are obtained by performing an inner join on two sequences. + /// + /// is . + /// is . + /// is . + /// is . + /// is . + public static IAsyncEnumerable Join( // satisfies the C# query-expression pattern + this IAsyncEnumerable outer, + IAsyncEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func resultSelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(outer); + ThrowHelper.ThrowIfNull(inner); + ThrowHelper.ThrowIfNull(outerKeySelector); + ThrowHelper.ThrowIfNull(innerKeySelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable outer, IAsyncEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func resultSelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e = outer.GetAsyncEnumerator(cancellationToken); + try + { + if (await e.MoveNextAsync().ConfigureAwait(false)) + { + AsyncLookup lookup = await AsyncLookup.CreateForJoinAsync(inner, innerKeySelector, comparer, cancellationToken).ConfigureAwait(false); + if (lookup.Count != 0) + { + do + { + TOuter item = e.Current; + Grouping? g = lookup.GetGrouping(outerKeySelector(item), create: false); + if (g is not null) + { + int count = g._count; + TInner[] elements = g._elements; + for (int i = 0; i != count; ++i) + { + yield return resultSelector(item, elements[i]); + } + } + } + while (await e.MoveNextAsync().ConfigureAwait(false)); + } + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Correlates the elements of two sequences based on matching keys. + /// The type of the elements of the first sequence. + /// The type of the elements of the second sequence. + /// The type of the keys returned by the key selector functions. + /// The type of the result elements. + /// The first sequence to join. + /// The sequence to join to the first sequence. + /// A function to extract the join key from each element of the first sequence. + /// A function to extract the join key from each element of the second sequence. + /// A function to create a result element from two matching elements. + /// An to use to hash and compare keys. + /// + /// An that has elements of type + /// that are obtained by performing an inner join on two sequences. + /// + /// is . + /// is . + /// is . + /// is . + /// is . + public static IAsyncEnumerable Join( + this IAsyncEnumerable outer, + IAsyncEnumerable inner, + Func> outerKeySelector, + Func> innerKeySelector, + Func> resultSelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(outer); + ThrowHelper.ThrowIfNull(inner); + ThrowHelper.ThrowIfNull(outerKeySelector); + ThrowHelper.ThrowIfNull(innerKeySelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable outer, + IAsyncEnumerable inner, + Func> outerKeySelector, + Func> innerKeySelector, + Func> resultSelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e = outer.GetAsyncEnumerator(cancellationToken); + try + { + if (await e.MoveNextAsync().ConfigureAwait(false)) + { + AsyncLookup lookup = await AsyncLookup.CreateForJoinAsync(inner, innerKeySelector, comparer, cancellationToken).ConfigureAwait(false); + if (lookup.Count != 0) + { + do + { + TOuter item = e.Current; + Grouping? g = lookup.GetGrouping(await outerKeySelector(item, cancellationToken).ConfigureAwait(false), create: false); + if (g is not null) + { + int count = g._count; + TInner[] elements = g._elements; + for (int i = 0; i != count; ++i) + { + yield return await resultSelector(item, elements[i], cancellationToken).ConfigureAwait(false); + } + } + } + while (await e.MoveNextAsync().ConfigureAwait(false)); + } + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/LastAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/LastAsync.cs new file mode 100644 index 00000000000000..768104af75dd68 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/LastAsync.cs @@ -0,0 +1,373 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Returns the last element of a sequence. + /// The type of the elements of source. + /// An to return the last element of. + /// The to monitor for cancellation requests. The default is . + /// The value at the last position in the source sequence. + /// is . + /// The source sequence is empty (via the returned task). + public static ValueTask LastAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + ThrowHelper.ThrowNoElementsException(); + } + + TSource result; + do + { + result = e.Current; + } + while (await e.MoveNextAsync().ConfigureAwait(false)); + + return result; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Returns the last element of a sequence that satisfies a specified condition. + /// The type of the elements of source. + /// An to return an element from. + /// A function to test each element for a condition. + /// The to monitor for cancellation requests. The default is . + /// The last element in the sequence that passes the test in the specified predicate function. + /// is . + /// is . + /// + /// The source sequence is empty, or no element in the sequence satisfies + /// the condition in predicate (via the returned task). + /// + public static ValueTask LastAsync( + this IAsyncEnumerable source, + Func predicate, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func predicate, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource element = e.Current; + if (predicate(element)) + { + TSource result = element; + + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + element = e.Current; + if (predicate(element)) + { + result = element; + } + } + + return result; + } + } + + ThrowHelper.ThrowNoMatchException(); + return default!; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Returns the last element of a sequence that satisfies a specified condition. + /// The type of the elements of source. + /// An to return an element from. + /// A function to test each element for a condition. + /// The to monitor for cancellation requests. The default is . + /// The last element in the sequence that passes the test in the specified predicate function. + /// is . + /// is . + /// + /// The source sequence is empty, or no element in the sequence satisfies + /// the condition in predicate (via the returned task). + /// + public static ValueTask LastAsync( + this IAsyncEnumerable source, + Func> predicate, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func> predicate, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource element = e.Current; + if (await predicate(element, cancellationToken).ConfigureAwait(false)) + { + TSource result = element; + + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + element = e.Current; + if (await predicate(element, cancellationToken).ConfigureAwait(false)) + { + result = element; + } + } + + return result; + } + } + + ThrowHelper.ThrowNoMatchException(); + return default!; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Returns the last element of a sequence, or a default value if the sequence contains no elements. + /// The type of the elements of source. + /// An to return an element from. + /// The to monitor for cancellation requests. The default is . + /// + /// The default value of if the source sequence is empty; + /// otherwise, the last element in the . + /// + /// is . + public static ValueTask LastOrDefaultAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) => + LastOrDefaultAsync(source, default(TSource), cancellationToken); + + /// Returns the last element of a sequence, or a default value if the sequence contains no elements. + /// The type of the elements of . + /// An to return the last element of. + /// The default value to return if the sequence is empty. + /// The to monitor for cancellation requests. The default is . + /// if the source sequence is empty; otherwise, the last element in the . + /// is . + public static ValueTask LastOrDefaultAsync( + this IAsyncEnumerable source, + TSource defaultValue, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source, defaultValue, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, TSource defaultValue, CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + TSource result = defaultValue; + if (await e.MoveNextAsync().ConfigureAwait(false)) + { + do + { + result = e.Current; + } + while (await e.MoveNextAsync().ConfigureAwait(false)); + } + + return result; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Returns the last element of a sequence that satisfies a condition or a default value if no such element is found. + /// The type of the elements of . + /// An to return an element from. + /// A function to test each element for a condition. + /// The to monitor for cancellation requests. The default is . + /// The default value of if the sequence is empty or if no elements pass the test in the predicate function; otherwise, the last element that passes the test in the predicate function. + /// is . + /// is . + public static ValueTask LastOrDefaultAsync( + this IAsyncEnumerable source, + Func predicate, + CancellationToken cancellationToken = default) => + LastOrDefaultAsync(source, predicate!, default, cancellationToken); + + /// Returns the last element of a sequence that satisfies a condition or a default value if no such element is found. + /// The type of the elements of . + /// An to return an element from. + /// A function to test each element for a condition. + /// The to monitor for cancellation requests. The default is . + /// The default value of if the sequence is empty or if no elements pass the test in the predicate function; otherwise, the last element that passes the test in the predicate function. + /// is . + /// is . + public static ValueTask LastOrDefaultAsync( + this IAsyncEnumerable source, + Func> predicate, + CancellationToken cancellationToken = default) => + LastOrDefaultAsync(source, predicate!, default, cancellationToken); + + /// Returns the last element of a sequence that satisfies a condition or a default value if no such element is found. + /// The type of the elements of . + /// An to return an element from. + /// A function to test each element for a condition. + /// The default value to return if the sequence is empty. + /// The to monitor for cancellation requests. The default is . + /// if the sequence is empty or if no elements pass the test in the predicate function; otherwise, the last element that passes the test in the predicate function. + /// is . + /// is . + public static ValueTask LastOrDefaultAsync( + this IAsyncEnumerable source, + Func predicate, + TSource defaultValue, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, defaultValue, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func predicate, + TSource defaultValue, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + TSource result = defaultValue; + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource element = e.Current; + if (predicate(element)) + { + result = element; + + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + element = e.Current; + if (predicate(element)) + { + result = element; + } + } + + break; + } + } + + return result; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Returns the last element of a sequence that satisfies a condition or a default value if no such element is found. + /// The type of the elements of . + /// An to return an element from. + /// A function to test each element for a condition. + /// The default value to return if the sequence is empty. + /// The to monitor for cancellation requests. The default is . + /// if the sequence is empty or if no elements pass the test in the predicate function; otherwise, the last element that passes the test in the predicate function. + /// is . + /// is . + public static ValueTask LastOrDefaultAsync( + this IAsyncEnumerable source, + Func> predicate, + TSource defaultValue, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, defaultValue, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, Func> predicate, TSource defaultValue, CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + TSource result = defaultValue; + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource element = e.Current; + if (await predicate(element, cancellationToken).ConfigureAwait(false)) + { + result = element; + + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + element = e.Current; + if (await predicate(element, cancellationToken).ConfigureAwait(false)) + { + result = element; + } + } + + break; + } + } + + return result; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/LeftJoin.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/LeftJoin.cs new file mode 100644 index 00000000000000..acce0f765694ad --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/LeftJoin.cs @@ -0,0 +1,164 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Correlates the elements of two sequences based on matching keys. + /// The first sequence to join. + /// The sequence to join to the first sequence. + /// A function to extract the join key from each element of the first sequence. + /// A function to extract the join key from each element of the second sequence. + /// A function to create a result element from two matching elements. + /// An to use to hash and compare keys. + /// The type of the elements of the first sequence. + /// The type of the elements of the second sequence. + /// The type of the keys returned by the key selector functions. + /// The type of the result elements. + /// An that has elements of type that are obtained by performing a left outer join on two sequences. + /// is . + /// is . + /// is . + /// is . + /// is . + public static IAsyncEnumerable LeftJoin( + this IAsyncEnumerable outer, + IAsyncEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func resultSelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(outer); + ThrowHelper.ThrowIfNull(inner); + ThrowHelper.ThrowIfNull(outerKeySelector); + ThrowHelper.ThrowIfNull(innerKeySelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable outer, IAsyncEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func resultSelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e = outer.GetAsyncEnumerator(cancellationToken); + try + { + if (await e.MoveNextAsync().ConfigureAwait(false)) + { + AsyncLookup innerLookup = await AsyncLookup.CreateForJoinAsync(inner, innerKeySelector, comparer, cancellationToken).ConfigureAwait(false); + do + { + TOuter item = e.Current; + Grouping? g = innerLookup.GetGrouping(outerKeySelector(item), create: false); + if (g is null) + { + yield return resultSelector(item, default); + } + else + { + int count = g._count; + TInner[] elements = g._elements; + for (int i = 0; i != count; ++i) + { + yield return resultSelector(item, elements[i]); + } + } + } + while (await e.MoveNextAsync().ConfigureAwait(false)); + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Correlates the elements of two sequences based on matching keys. + /// The first sequence to join. + /// The sequence to join to the first sequence. + /// A function to extract the join key from each element of the first sequence. + /// A function to extract the join key from each element of the second sequence. + /// A function to create a result element from two matching elements. + /// An to use to hash and compare keys. + /// The type of the elements of the first sequence. + /// The type of the elements of the second sequence. + /// The type of the keys returned by the key selector functions. + /// The type of the result elements. + /// An that has elements of type that are obtained by performing a left outer join on two sequences. + /// is . + /// is . + /// is . + /// is . + /// is . + public static IAsyncEnumerable LeftJoin( + this IAsyncEnumerable outer, + IAsyncEnumerable inner, + Func> outerKeySelector, + Func> innerKeySelector, + Func> resultSelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(outer); + ThrowHelper.ThrowIfNull(inner); + ThrowHelper.ThrowIfNull(outerKeySelector); + ThrowHelper.ThrowIfNull(innerKeySelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable outer, + IAsyncEnumerable inner, + Func> outerKeySelector, + Func> innerKeySelector, + Func> resultSelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e = outer.GetAsyncEnumerator(cancellationToken); + try + { + if (await e.MoveNextAsync().ConfigureAwait(false)) + { + AsyncLookup innerLookup = await AsyncLookup.CreateForJoinAsync(inner, innerKeySelector, comparer, cancellationToken).ConfigureAwait(false); + do + { + TOuter item = e.Current; + Grouping? g = innerLookup.GetGrouping(await outerKeySelector(item, cancellationToken).ConfigureAwait(false), create: false); + if (g is null) + { + yield return await resultSelector(item, default, cancellationToken).ConfigureAwait(false); + } + else + { + int count = g._count; + TInner[] elements = g._elements; + for (int i = 0; i != count; ++i) + { + yield return await resultSelector(item, elements[i], cancellationToken).ConfigureAwait(false); + } + } + } + while (await e.MoveNextAsync().ConfigureAwait(false)); + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/MaxAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/MaxAsync.cs new file mode 100644 index 00000000000000..8dcf84371519aa --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/MaxAsync.cs @@ -0,0 +1,272 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Returns the maximum value in a generic sequence. + /// The type of the elements of . + /// A sequence of values to determine the maximum value of. + /// The to compare values. + /// The to monitor for cancellation requests. The default is . + /// The maximum value in the sequence. + /// is . + /// No object in implements the or interface (via the returned task). + /// + /// If type implements , the method uses that implementation to compare values. Otherwise, if type implements , that implementation is used to compare values. + /// If is a reference type and the source sequence is empty or contains only values that are , this method returns . + /// + public static ValueTask MaxAsync( + this IAsyncEnumerable source, + IComparer? comparer = null, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + comparer ??= Comparer.Default; + + // Special-case float/double/float?/double? to maintain compatibility + // with System.Linq.Enumerable implementations. +#pragma warning disable CA2012 // Use ValueTasks correctly + if (typeof(TSource) == typeof(float) && comparer == Comparer.Default) + { + return (ValueTask)(object)MaxAsync((IAsyncEnumerable)(object)source, cancellationToken); + } + + if (typeof(TSource) == typeof(double) && comparer == Comparer.Default) + { + return (ValueTask)(object)MaxAsync((IAsyncEnumerable)(object)source, cancellationToken); + } + + if (typeof(TSource) == typeof(float?) && comparer == Comparer.Default) + { + return (ValueTask)(object)MaxAsync((IAsyncEnumerable)(object)source, cancellationToken); + } + + if (typeof(TSource) == typeof(double?) && comparer == Comparer.Default) + { + return (ValueTask)(object)MaxAsync((IAsyncEnumerable)(object)source, cancellationToken); + } +#pragma warning restore CA2012 + + return Impl(source, comparer, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + IComparer comparer, + CancellationToken cancellationToken) + { + TSource? value = default; + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (default(TSource) is null) + { + do + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + return value; + } + + value = e.Current; + } + while (value is null); + + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource next = e.Current; + if (next is not null && comparer.Compare(next, value) > 0) + { + value = next; + } + } + } + else + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + ThrowHelper.ThrowNoElementsException(); + } + + value = e.Current; + if (comparer == Comparer.Default) + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource next = e.Current; + if (Comparer.Default.Compare(next, value) > 0) + { + value = next; + } + } + } + else + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource next = e.Current; + if (comparer.Compare(next, value) > 0) + { + value = next; + } + } + } + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + + return value; + } + } + + /// Returns the maximum value in a sequence of values. + /// A sequence of values to determine the maximum value of. + /// The to monitor for cancellation requests. The default is . + /// The maximum value in the sequence. + private static async ValueTask MaxAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + ThrowHelper.ThrowNoElementsException(); + } + + // NaN is ordered less than all other values. We need to do explicit checks to ensure this, + // but once we've found a value that is not NaN we need no longer worry about it, + // so first loop until such a value is found (or not, as the case may be). + float value = e.Current; + while (float.IsNaN(value)) + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + return value; + } + + value = e.Current; + } + + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + float x = e.Current; + if (x > value) + { + value = x; + } + } + + return value; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + + /// Returns the maximum value in a sequence of values. + /// A sequence of values to determine the maximum value of. + /// The to monitor for cancellation requests. The default is . + /// The maximum value in the sequence. + private static async ValueTask MaxAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + ThrowHelper.ThrowNoElementsException(); + } + + // NaN is ordered less than all other values. We need to do explicit checks to ensure this, + // but once we've found a value that is not NaN we need no longer worry about it, + // so first loop until such a value is found (or not, as the case may be). + double value = e.Current; + while (double.IsNaN(value)) + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + return value; + } + + value = e.Current; + } + + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + double x = e.Current; + if (x > value) + { + value = x; + } + } + + return value; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + + /// Returns the maximum value in a sequence of nullable values. + /// A sequence of nullable values to determine the maximum value of. + /// The to monitor for cancellation requests. The default is . + /// The maximum value in the sequence. + private static async ValueTask MaxAsync(IAsyncEnumerable source, CancellationToken cancellationToken) + { + float? value = null; + await foreach (float? x in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (x is null) + { + continue; + } + + if (value is null || x > value || float.IsNaN((float)value)) + { + value = x; + } + } + + return value; + } + + /// Returns the maximum value in a sequence of nullable values. + /// A sequence of nullable values to determine the maximum value of. + /// The to monitor for cancellation requests. The default is . + /// The maximum value in the sequence. + private static async ValueTask MaxAsync(IAsyncEnumerable source, CancellationToken cancellationToken) + { + double? value = null; + await foreach (double? x in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (x is null) + { + continue; + } + + if (value is null || x > value || double.IsNaN((double)value)) + { + value = x; + } + } + + return value; + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/MaxByAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/MaxByAsync.cs new file mode 100644 index 00000000000000..4a3d3ecd9484c2 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/MaxByAsync.cs @@ -0,0 +1,244 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Returns the maximum value in a generic sequence according to a specified key selector function. + /// The type of the elements of . + /// The type of key to compare elements by. + /// A sequence of values to determine the maximum value of. + /// A function to extract the key for each element. + /// The to compare keys. + /// The to monitor for cancellation requests. The default is . + /// The value with the maximum key in the sequence. + /// is . + /// No key extracted from implements the or interface. + /// + /// If is a reference type and the source sequence is empty or contains only values that are , this method returns . + /// + public static ValueTask MaxByAsync( + this IAsyncEnumerable source, + Func keySelector, + IComparer? comparer = null, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(source, keySelector, comparer ?? Comparer.Default, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func keySelector, + IComparer comparer, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + if (default(TSource) is not null) + { + ThrowHelper.ThrowNoElementsException(); + } + + return default; + } + + TSource value = e.Current; + TKey key = keySelector(value); + + if (default(TKey) is null) + { + if (key is null) + { + TSource firstValue = value; + + do + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + // All keys are null, surface the first element. + return firstValue; + } + + value = e.Current; + key = keySelector(value); + } + while (key is null); + } + + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource nextValue = e.Current; + TKey nextKey = keySelector(nextValue); + if (nextKey is not null && comparer.Compare(nextKey, key) > 0) + { + key = nextKey; + value = nextValue; + } + } + } + else + { + if (comparer == Comparer.Default) + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource nextValue = e.Current; + TKey nextKey = keySelector(nextValue); + if (Comparer.Default.Compare(nextKey, key) > 0) + { + key = nextKey; + value = nextValue; + } + } + } + else + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource nextValue = e.Current; + TKey nextKey = keySelector(nextValue); + if (comparer.Compare(nextKey, key) > 0) + { + key = nextKey; + value = nextValue; + } + } + } + } + + return value; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Returns the maximum value in a generic sequence according to a specified key selector function. + /// The type of the elements of . + /// The type of key to compare elements by. + /// A sequence of values to determine the maximum value of. + /// A function to extract the key for each element. + /// The to compare keys. + /// The to monitor for cancellation requests. The default is . + /// The value with the maximum key in the sequence. + /// is . + /// No key extracted from implements the or interface. + /// + /// If is a reference type and the source sequence is empty or contains only values that are , this method returns . + /// + public static ValueTask MaxByAsync( + this IAsyncEnumerable source, + Func> keySelector, + IComparer? comparer = null, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(source, keySelector, comparer ?? Comparer.Default, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func> keySelector, + IComparer comparer, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + if (default(TSource) is not null) + { + ThrowHelper.ThrowNoElementsException(); + } + + return default; + } + + TSource value = e.Current; + TKey key = await keySelector(value, cancellationToken).ConfigureAwait(false); + + if (default(TKey) is null) + { + if (key is null) + { + TSource firstValue = value; + + do + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + // All keys are null, surface the first element. + return firstValue; + } + + value = e.Current; + key = await keySelector(value, cancellationToken).ConfigureAwait(false); + } + while (key is null); + } + + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource nextValue = e.Current; + TKey nextKey = await keySelector(nextValue, cancellationToken).ConfigureAwait(false); + if (nextKey is not null && comparer.Compare(nextKey, key) > 0) + { + key = nextKey; + value = nextValue; + } + } + } + else + { + if (comparer == Comparer.Default) + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource nextValue = e.Current; + TKey nextKey = await keySelector(nextValue, cancellationToken).ConfigureAwait(false); + if (Comparer.Default.Compare(nextKey, key) > 0) + { + key = nextKey; + value = nextValue; + } + } + } + else + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource nextValue = e.Current; + TKey nextKey = await keySelector(nextValue, cancellationToken).ConfigureAwait(false); + if (comparer.Compare(nextKey, key) > 0) + { + key = nextKey; + value = nextValue; + } + } + } + } + + return value; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/MinAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/MinAsync.cs new file mode 100644 index 00000000000000..34e4ceec262bdf --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/MinAsync.cs @@ -0,0 +1,283 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Returns the minimum value in a generic sequence. + /// The type of the elements of . + /// A sequence of values to determine the minimum value of. + /// The to compare values. + /// The to monitor for cancellation requests. The default is . + /// The minimum value in the sequence. + /// is . + /// No object in implements the or interface. + /// + /// If type implements , the method uses that implementation to compare values. Otherwise, if type implements , that implementation is used to compare values. + /// If is a reference type and the source sequence is empty or contains only values that are , this method returns . + /// + public static ValueTask MinAsync( + this IAsyncEnumerable source, + IComparer? comparer = null, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + comparer ??= Comparer.Default; + + // Special-case float/double/float?/double? to maintain compatibility + // with System.Linq.Enumerable implementations. +#pragma warning disable CA2012 // Use ValueTasks correctly + if (typeof(TSource) == typeof(float) && comparer == Comparer.Default) + { + return (ValueTask)(object)MinAsync((IAsyncEnumerable)(object)source, cancellationToken); + } + + if (typeof(TSource) == typeof(double) && comparer == Comparer.Default) + { + return (ValueTask)(object)MinAsync((IAsyncEnumerable)(object)source, cancellationToken); + } + + if (typeof(TSource) == typeof(float?) && comparer == Comparer.Default) + { + return (ValueTask)(object)MinAsync((IAsyncEnumerable)(object)source, cancellationToken); + } + + if (typeof(TSource) == typeof(double?) && comparer == Comparer.Default) + { + return (ValueTask)(object)MinAsync((IAsyncEnumerable)(object)source, cancellationToken); + } +#pragma warning restore CA2012 + + return Impl(source, comparer, cancellationToken); + + static async ValueTask Impl(IAsyncEnumerable source, IComparer comparer, CancellationToken cancellationToken) + { + TSource? value = default; + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (default(TSource) is null) + { + do + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + return value; + } + + value = e.Current; + } + while (value is null); + + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource next = e.Current; + if (next is not null && comparer.Compare(next, value) < 0) + { + value = next; + } + } + } + else + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + ThrowHelper.ThrowNoElementsException(); + } + + value = e.Current; + if (comparer == Comparer.Default) + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource next = e.Current; + if (Comparer.Default.Compare(next, value) < 0) + { + value = next; + } + } + } + else + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource next = e.Current; + if (comparer.Compare(next, value) < 0) + { + value = next; + } + } + } + } + + return value; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Returns the minimum value in a sequence of values. + /// A sequence of values to determine the minimum value of. + /// The to monitor for cancellation requests. The default is . + /// The minimum value in the sequence. + private static async ValueTask MinAsync( + IAsyncEnumerable source, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + ThrowHelper.ThrowNoElementsException(); + } + + float value = e.Current; + if (float.IsNaN(value)) + { + return value; + } + + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + float x = e.Current; + if (x < value) + { + value = x; + } + + // Normally NaN < anything is false, as is anything < NaN + // However, this leads to some irksome outcomes in Min and Max. + // If we use those semantics then Min(NaN, 5.0) is NaN, but + // Min(5.0, NaN) is 5.0! To fix this, we impose a total + // ordering where NaN is smaller than every value, including + // negative infinity. Not testing for NaN therefore isn't an option, but since we + // can't find a smaller value, we can short-circuit. + else if (float.IsNaN(x)) + { + return x; + } + } + + return value; + + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + + /// Returns the minimum value in a sequence of values. + /// A sequence of values to determine the minimum value of. + /// The to monitor for cancellation requests. The default is . + /// The minimum value in the sequence. + private static async ValueTask MinAsync( + IAsyncEnumerable source, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + ThrowHelper.ThrowNoElementsException(); + } + + double value = e.Current; + if (double.IsNaN(value)) + { + return value; + } + + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + double x = e.Current; + if (x < value) + { + value = x; + } + + // Normally NaN < anything is false, as is anything < NaN + // However, this leads to some irksome outcomes in Min and Max. + // If we use those semantics then Min(NaN, 5.0) is NaN, but + // Min(5.0, NaN) is 5.0! To fix this, we impose a total + // ordering where NaN is smaller than every value, including + // negative infinity. Not testing for NaN therefore isn't an option, but since we + // can't find a smaller value, we can short-circuit. + else if (double.IsNaN(x)) + { + return x; + } + } + + return value; + + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + + /// Returns the minimum value in a sequence of nullable values. + /// A sequence of nullable values to determine the minimum value of. + /// The to monitor for cancellation requests. The default is . + /// The minimum value in the sequence. + private static async ValueTask MinAsync( + IAsyncEnumerable source, + CancellationToken cancellationToken) + { + float? value = null; + await foreach (float? x in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (x is null) + { + continue; + } + + if (value == null || x < value || float.IsNaN(x.GetValueOrDefault())) + { + value = x; + } + } + + return value; + } + + /// Returns the minimum value in a sequence of nullable values. + /// A sequence of nullable values to determine the minimum value of. + /// The to monitor for cancellation requests. The default is . + /// The minimum value in the sequence. + private static async ValueTask MinAsync( + IAsyncEnumerable source, + CancellationToken cancellationToken) + { + double? value = null; + await foreach (double? x in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (x is null) + { + continue; + } + + if (value == null || x < value || double.IsNaN(x.GetValueOrDefault())) + { + value = x; + } + } + + return value; + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/MinByAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/MinByAsync.cs new file mode 100644 index 00000000000000..e2354543570b3b --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/MinByAsync.cs @@ -0,0 +1,244 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Returns the minimum value in a generic sequence according to a specified key selector function. + /// The type of the elements of . + /// The type of key to compare elements by. + /// A sequence of values to determine the minimum value of. + /// A function to extract the key for each element. + /// The to compare keys. + /// The to monitor for cancellation requests. The default is . + /// The value with the minimum key in the sequence. + /// is . + /// No key extracted from implements the or interface. + /// + /// If is a reference type and the source sequence is empty or contains only values that are , this method returns . + /// + public static ValueTask MinByAsync( + this IAsyncEnumerable source, + Func keySelector, + IComparer? comparer = null, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(source, keySelector, comparer ?? Comparer.Default, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func keySelector, + IComparer comparer, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + if (default(TSource) is not null) + { + ThrowHelper.ThrowNoElementsException(); + } + + return default; + } + + TSource value = e.Current; + TKey key = keySelector(value); + + if (default(TKey) is null) + { + if (key is null) + { + TSource firstValue = value; + + do + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + // All keys are null, surface the first element. + return firstValue; + } + + value = e.Current; + key = keySelector(value); + } + while (key is null); + } + + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource nextValue = e.Current; + TKey nextKey = keySelector(nextValue); + if (nextKey is not null && comparer.Compare(nextKey, key) < 0) + { + key = nextKey; + value = nextValue; + } + } + } + else + { + if (comparer == Comparer.Default) + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource nextValue = e.Current; + TKey nextKey = keySelector(nextValue); + if (Comparer.Default.Compare(nextKey, key) < 0) + { + key = nextKey; + value = nextValue; + } + } + } + else + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource nextValue = e.Current; + TKey nextKey = keySelector(nextValue); + if (comparer.Compare(nextKey, key) < 0) + { + key = nextKey; + value = nextValue; + } + } + } + } + + return value; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Returns the minimum value in a generic sequence according to a specified key selector function. + /// The type of the elements of . + /// The type of key to compare elements by. + /// A sequence of values to determine the minimum value of. + /// A function to extract the key for each element. + /// The to compare keys. + /// The to monitor for cancellation requests. The default is . + /// The value with the minimum key in the sequence. + /// is . + /// No key extracted from implements the or interface. + /// + /// If is a reference type and the source sequence is empty or contains only values that are , this method returns . + /// + public static ValueTask MinByAsync( + this IAsyncEnumerable source, + Func> keySelector, + IComparer? comparer = null, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(source, keySelector, comparer ?? Comparer.Default, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func> keySelector, + IComparer comparer, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + if (default(TSource) is not null) + { + ThrowHelper.ThrowNoElementsException(); + } + + return default; + } + + TSource value = e.Current; + TKey key = await keySelector(value, cancellationToken).ConfigureAwait(false); + + if (default(TKey) is null) + { + if (key is null) + { + TSource firstValue = value; + + do + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + // All keys are null, surface the first element. + return firstValue; + } + + value = e.Current; + key = await keySelector(value, cancellationToken).ConfigureAwait(false); + } + while (key is null); + } + + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource nextValue = e.Current; + TKey nextKey = await keySelector(nextValue, cancellationToken).ConfigureAwait(false); + if (nextKey is not null && comparer.Compare(nextKey, key) < 0) + { + key = nextKey; + value = nextValue; + } + } + } + else + { + if (comparer == Comparer.Default) + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource nextValue = e.Current; + TKey nextKey = await keySelector(nextValue, cancellationToken).ConfigureAwait(false); + if (Comparer.Default.Compare(nextKey, key) < 0) + { + key = nextKey; + value = nextValue; + } + } + } + else + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource nextValue = e.Current; + TKey nextKey = await keySelector(nextValue, cancellationToken).ConfigureAwait(false); + if (comparer.Compare(nextKey, key) < 0) + { + key = nextKey; + value = nextValue; + } + } + } + } + + return value; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/OfType.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/OfType.cs new file mode 100644 index 00000000000000..2aa02ebb54858b --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/OfType.cs @@ -0,0 +1,45 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + // TODO https://github.com/dotnet/runtime/issues/111717: + // Consider before shipping .NET 10 whether this can instead use extension everything to support any IAsyncEnumerable source. + // Right now it's limited because you can't cast an IAsyncEnumerable to IAsyncEnumerable, but this keeps it in + // sync with Cast, which needs its shape in support of query comprehensions. + + /// + /// Filters the elements of a based on a specified type . + /// + /// The type to filter the elements of the sequence on. + /// The whose elements to filter. + /// An that contains elements from the input sequence of type . + public static IAsyncEnumerable OfType( + this IAsyncEnumerable source) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (object? item in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (item is TResult target) + { + yield return target; + } + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/OrderBy.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/OrderBy.cs new file mode 100644 index 00000000000000..64414b2e11bc7b --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/OrderBy.cs @@ -0,0 +1,428 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Sorts the elements of a sequence in ascending order. + /// The type of the elements of . + /// A sequence of values to order. + /// An to compare keys. + /// An whose elements are sorted. + /// is . + public static IOrderedAsyncEnumerable Order( + this IAsyncEnumerable source, + IComparer? comparer = null) => + OrderBy(source, EnumerableSorter.IdentityFunc, comparer); + + /// Sorts the elements of a sequence in ascending order. + /// The type of the elements of . + /// The type of the key returned by . + /// A sequence of values to order. + /// A function to extract a key from an element. + /// An to compare keys. + /// An whose elements are sorted according to a key. + /// is . + /// is . + public static IOrderedAsyncEnumerable OrderBy( // satisfies the C# query-expression pattern + this IAsyncEnumerable source, + Func keySelector, + IComparer? comparer = null) => + new OrderedIterator(source, keySelector, comparer, false, null); + + /// Sorts the elements of a sequence in ascending order. + /// The type of the elements of . + /// The type of the key returned by . + /// A sequence of values to order. + /// A function to extract a key from an element. + /// An to compare keys. + /// An whose elements are sorted according to a key. + /// is . + /// is . + public static IOrderedAsyncEnumerable OrderBy( + this IAsyncEnumerable source, + Func> keySelector, + IComparer? comparer = null) => + new OrderedIterator(source, keySelector, comparer, false, null); + + /// Sorts the elements of a sequence in descending order. + /// The type of the elements of . + /// A sequence of values to order. + /// An to compare keys. + /// An whose elements are sorted in descending order. + /// is . + public static IOrderedAsyncEnumerable OrderDescending( + this IAsyncEnumerable source, + IComparer? comparer = null) => + OrderByDescending(source, EnumerableSorter.IdentityFunc, comparer); + + /// Sorts the elements of a sequence in descending order. + /// The type of the elements of . + /// The type of the key returned by . + /// A sequence of values to order. + /// A function to extract a key from an element. + /// An to compare keys. + /// An whose elements are sorted in descending order according to a key. + /// is . + /// is . + public static IOrderedAsyncEnumerable OrderByDescending( // satisfies the C# query-expression pattern + this IAsyncEnumerable source, + Func keySelector, + IComparer? comparer = null) => + new OrderedIterator(source, keySelector, comparer, true, null); + + /// Sorts the elements of a sequence in descending order. + /// The type of the elements of . + /// The type of the key returned by . + /// A sequence of values to order. + /// A function to extract a key from an element. + /// An to compare keys. + /// An whose elements are sorted in descending order according to a key. + /// is . + /// is . + public static IOrderedAsyncEnumerable OrderByDescending( + this IAsyncEnumerable source, + Func> keySelector, + IComparer? comparer = null) => + new OrderedIterator(source, keySelector, comparer, true, null); + + /// Performs a subsequent ordering of the elements in a sequence in ascending order. + /// The type of the elements of . + /// The type of the key returned by . + /// A sequence of values to order. + /// A function to extract a key from each element. + /// An to compare keys. + /// An whose elements are sorted according to a key. + /// is . + /// is . + public static IOrderedAsyncEnumerable ThenBy( // satisfies the C# query-expression pattern + this IOrderedAsyncEnumerable source, + Func keySelector, + IComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(source); + + return source.CreateOrderedAsyncEnumerable(keySelector, comparer, descending: false); + } + + /// Performs a subsequent ordering of the elements in a sequence in ascending order. + /// The type of the elements of . + /// The type of the key returned by . + /// A sequence of values to order. + /// A function to extract a key from each element. + /// An to compare keys. + /// An whose elements are sorted according to a key. + /// is . + /// is . + public static IOrderedAsyncEnumerable ThenBy( + this IOrderedAsyncEnumerable source, + Func> keySelector, + IComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(source); + + return source.CreateOrderedAsyncEnumerable(keySelector, comparer, descending: false); + } + + /// Performs a subsequent ordering of the elements in a sequence in descending order. + /// The type of the elements of . + /// The type of the key returned by . + /// A sequence of values to order. + /// A function to extract a key from each element. + /// An to compare keys. + /// An whose elements are sorted in descending order according to a key. + /// is . + /// is . + public static IOrderedAsyncEnumerable ThenByDescending( // satisfies the C# query-expression pattern + this IOrderedAsyncEnumerable source, + Func keySelector, + IComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(source); + + return source.CreateOrderedAsyncEnumerable(keySelector, comparer, descending: true); + } + + /// Performs a subsequent ordering of the elements in a sequence in descending order. + /// The type of the elements of . + /// The type of the key returned by . + /// A sequence of values to order. + /// A function to extract a key from each element. + /// An to compare keys. + /// An whose elements are sorted in descending order according to a key. + /// is . + /// is . + public static IOrderedAsyncEnumerable ThenByDescending( + this IOrderedAsyncEnumerable source, + Func> keySelector, + IComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(source); + + return source.CreateOrderedAsyncEnumerable(keySelector, comparer, descending: true); + } + + private abstract partial class OrderedIterator : IOrderedAsyncEnumerable + { + internal readonly IAsyncEnumerable _source; + + protected OrderedIterator(IAsyncEnumerable source) => _source = source; + + private protected ValueTask CreateSortedMapAsync(TElement[] buffer, CancellationToken cancellationToken) => + GetEnumerableSorter().SortAsync(buffer, buffer.Length, cancellationToken); + + internal abstract EnumerableSorter GetEnumerableSorter(EnumerableSorter? next = null); + + public IOrderedAsyncEnumerable CreateOrderedAsyncEnumerable(Func keySelector, IComparer? comparer, bool descending) => + new OrderedIterator(_source, keySelector, comparer, @descending, this); + + public IOrderedAsyncEnumerable CreateOrderedAsyncEnumerable(Func> keySelector, IComparer? comparer, bool descending) => + new OrderedIterator(_source, keySelector, comparer, @descending, this); + + public abstract IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken); + } + + private sealed partial class OrderedIterator : OrderedIterator + { + private readonly OrderedIterator? _parent; + private readonly object _keySelector; + private readonly IComparer _comparer; + private readonly bool _descending; + + internal OrderedIterator(IAsyncEnumerable source, object keySelector, IComparer? comparer, bool descending, OrderedIterator? parent) : + base(source) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + + Debug.Assert(keySelector is Func or Func>); + + _parent = parent; + _keySelector = keySelector; + _comparer = comparer ?? Comparer.Default; + _descending = descending; + } + + internal override EnumerableSorter GetEnumerableSorter(EnumerableSorter? next) + { + // Special case the common use of string with default comparer. Comparer.Default checks the + // thread's Culture on each call which is an overhead which is not required, because we are about to + // do a sort which remains on the current thread (and EnumerableSorter is not used afterwards). + IComparer comparer = _comparer; + if (typeof(TKey) == typeof(string) && comparer == Comparer.Default) + { + comparer = (IComparer)StringComparer.CurrentCulture; + } + + EnumerableSorter sorter = new EnumerableSorter(_keySelector, comparer, _descending, next); + if (_parent is not null) + { + sorter = _parent.GetEnumerableSorter(sorter); + } + + return sorter; + } + + public override async IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken) + { + TElement[] buffer = await _source.ToArrayAsync(cancellationToken).ConfigureAwait(false); + if (buffer.Length > 0) + { + int[] map = await CreateSortedMapAsync(buffer, cancellationToken).ConfigureAwait(false); + for (int i = 0; i < map.Length; i++) + { + yield return buffer[map[i]]; + } + } + } + } + + private abstract class EnumerableSorter : IComparer + { + /// Function that returns its input unmodified. + /// + /// Used for reference equality in order to avoid unnecessary computation when a caller + /// can benefit from knowing that the produced value is identical to the input. + /// + internal static readonly Func IdentityFunc = e => e; + + internal abstract Task ComputeKeysAsync(TElement[] elements, int count, CancellationToken cancellationToken); + + public abstract int Compare(int index1, int index2); + + internal async ValueTask SortAsync(TElement[] elements, int count, CancellationToken cancellationToken) + { + await ComputeKeysAsync(elements, count, cancellationToken).ConfigureAwait(false); + + int[] map = new int[count]; + for (int i = 0; i < map.Length; i++) + { + map[i] = i; + } + + QuickSort(map, 0, count - 1); + + return map; + } + + protected abstract void QuickSort(int[] map, int left, int right); + } + + private sealed class EnumerableSorter : EnumerableSorter, IComparer + { + private readonly object _keySelector; + private readonly IComparer _comparer; + private readonly bool _descending; + private readonly EnumerableSorter? _next; + private TKey[]? _keys; + + internal EnumerableSorter(object keySelector, IComparer comparer, bool descending, EnumerableSorter? next) + { + _keySelector = keySelector; + _comparer = comparer; + _descending = descending; + _next = next; + } + + internal override async Task ComputeKeysAsync(TElement[] elements, int count, CancellationToken cancellationToken) + { + object keySelector = _keySelector; + if (ReferenceEquals(keySelector, IdentityFunc)) + { + // The key selector is our known identity function, which means we don't + // need to invoke the key selector for every element. Further, we can just + // use the original array as the keys (even if count is smaller, as the additional + // values will just be ignored). + Debug.Assert(typeof(TKey) == typeof(TElement)); + _keys = (TKey[])(object)elements; + } + else + { + var keys = new TKey[count]; + if (keySelector is Func syncSelector) + { + for (int i = 0; i < keys.Length; i++) + { + keys[i] = syncSelector(elements[i]); + } + } + else + { + var asyncSelector = (Func>)keySelector; + for (int i = 0; i < keys.Length; i++) + { + keys[i] = await asyncSelector(elements[i], cancellationToken).ConfigureAwait(false); + } + } + _keys = keys; + } + + _next?.ComputeKeysAsync(elements, count, cancellationToken); + } + + public override int Compare(int index1, int index2) + { + TKey[]? keys = _keys; + Debug.Assert(keys is not null); + + int c = _comparer.Compare(keys[index1], keys[index2]); + if (c == 0) + { + if (_next is null) + { + return index1 - index2; // ensure stability of sort + } + + return _next.Compare(index1, index2); + } + + // -c will result in a negative value for int.MinValue (-int.MinValue == int.MinValue). + // Flipping keys earlier is more likely to trigger something strange in a comparer, + // particularly as it comes to the sort being stable. + return (_descending != (c > 0)) ? 1 : -1; + } + + protected override void QuickSort(int[] keys, int lo, int hi) + { +#if NET + if (typeof(TKey).IsValueType && _next is null && _comparer == Comparer.Default) + { + // We can use Comparer.Default.Compare and benefit from devirtualization and inlining. + // We can also avoid extra steps to check whether we need to deal with a subsequent tie breaker (_next). + new Span(keys, lo, hi - lo + 1).Sort(!_descending ? + Compare_DefaultComparer_NoNext_Ascending : + Compare_DefaultComparer_NoNext_Descending); + + int Compare_DefaultComparer_NoNext_Ascending(int index1, int index2) + { + Debug.Assert(typeof(TKey).IsValueType); + Debug.Assert(_comparer == Comparer.Default); + Debug.Assert(_next is null); + Debug.Assert(!_descending); + + TKey[]? keys = _keys; + Debug.Assert(keys is not null); + + int c = Comparer.Default.Compare(keys[index1], keys[index2]); + return + c == 0 ? index1 - index2 : // ensure stability of sort + c; + } + + int Compare_DefaultComparer_NoNext_Descending(int index1, int index2) + { + Debug.Assert(typeof(TKey).IsValueType); + Debug.Assert(_comparer == Comparer.Default); + Debug.Assert(_next is null); + Debug.Assert(_descending); + + TKey[]? keys = _keys; + Debug.Assert(keys is not null); + + int c = Comparer.Default.Compare(keys[index2], keys[index1]); + return + c == 0 ? index1 - index2 : // ensure stability of sort + c; + } + } + else +#endif + { +#if NET + new Span(keys, lo, hi - lo + 1).Sort(Compare); +#else + Array.Sort(keys, lo, hi - lo + 1, this); +#endif + } + } + } + } + + /// Represents a sorted asynchronous sequence. + /// The type of the elements of the sequence. + /// This interface is not intended to be implemented by user code. It supports the .NET infrastructure. + public interface IOrderedAsyncEnumerable : IAsyncEnumerable + { + /// Performs a subsequent ordering on the elements of an according to a key. + /// The type of the key produced by . + /// The function used to extract the key for each element. + /// The used to compare keys for placement in the returned sequence. + /// true to sort the elements in descending order; false to sort the elements in ascending order. + /// An whose elements are sorted according to a key. + IOrderedAsyncEnumerable CreateOrderedAsyncEnumerable(Func keySelector, IComparer? comparer, bool descending); + + /// Performs a subsequent ordering on the elements of an according to a key. + /// The type of the key produced by . + /// The function used to extract the key for each element. + /// The used to compare keys for placement in the returned sequence. + /// true to sort the elements in descending order; false to sort the elements in ascending order. + /// An whose elements are sorted according to a key. + IOrderedAsyncEnumerable CreateOrderedAsyncEnumerable(Func> keySelector, IComparer? comparer, bool descending); + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Prepend.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Prepend.cs new file mode 100644 index 00000000000000..ffaa302a9882de --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Prepend.cs @@ -0,0 +1,41 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Adds a value to the beginning of the sequence. + /// The type of the elements of source. + /// A sequence of values. + /// The value to prepend to source. + /// A new sequence that begins with element. + /// is . + public static IAsyncEnumerable Prepend( + this IAsyncEnumerable source, + TSource element) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source, element, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + TSource element, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + yield return element; + + await foreach (TSource item in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return item; + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Range.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Range.cs new file mode 100644 index 00000000000000..315dca71946459 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Range.cs @@ -0,0 +1,39 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Generates a sequence of integral numbers within a specified range. + /// The value of the first integer in the sequence. + /// The number of sequential integers to generate. + /// An that contains a range of sequential integral numbers. + /// is less than 0 + /// + -1 is larger than . + public static IAsyncEnumerable Range(int start, int count) + { + if (count == 0) + { + return Empty(); + } + + if (count < 0 || (((long)start) + count - 1) > int.MaxValue) + { + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(count)); + } + + return Impl(start, count); + + static async IAsyncEnumerable Impl(int start, int count) + { + for (int i = 0; i < count; i++) + { + yield return start + i; + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Repeat.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Repeat.cs new file mode 100644 index 00000000000000..25350b7dfd3d8a --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Repeat.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Generates a sequence that contains one repeated value. + /// The type of the value to be repeated in the result sequence. + /// The value to be repeated. + /// The number of times to repeat the value in the generated sequence. + /// An that contains a repeated value. + /// is less than 0. + public static IAsyncEnumerable Repeat(TResult element, int count) + { + if (count == 0) + { + return Empty(); + } + + ThrowHelper.ThrowIfNegative(count); + + return Impl(element, count); + + static async IAsyncEnumerable Impl(TResult element, int count) + { + while (count-- != 0) + { + yield return element; + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Reverse.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Reverse.cs new file mode 100644 index 00000000000000..4ecc083ca54a13 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Reverse.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Inverts the order of the elements in a sequence. + /// The type of the elements of source. + /// A sequence of values to reverse. + /// A sequence whose elements correspond to those of the input sequence in reverse order. + /// is . + public static IAsyncEnumerable Reverse( + this IAsyncEnumerable source) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + TSource[] array = await source.ToArrayAsync(cancellationToken).ConfigureAwait(false); + for (int i = array.Length - 1; i >= 0; i--) + { + yield return array[i]; + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/RightJoin.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/RightJoin.cs new file mode 100644 index 00000000000000..810496d53ea793 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/RightJoin.cs @@ -0,0 +1,165 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Correlates the elements of two sequences based on matching keys. + /// The first sequence to join. + /// The sequence to join to the first sequence. + /// A function to extract the join key from each element of the first sequence. + /// A function to extract the join key from each element of the second sequence. + /// A function to create a result element from two matching elements. + /// An to use to hash and compare keys. + /// The type of the elements of the first sequence. + /// The type of the elements of the second sequence. + /// The type of the keys returned by the key selector functions. + /// The type of the result elements. + /// An that has elements of type that are obtained by performing a right outer join on two sequences. + /// is . + /// is . + /// is . + /// is . + /// is . + public static IAsyncEnumerable RightJoin( + this IAsyncEnumerable outer, + IAsyncEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func resultSelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(outer); + ThrowHelper.ThrowIfNull(inner); + ThrowHelper.ThrowIfNull(outerKeySelector); + ThrowHelper.ThrowIfNull(innerKeySelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable outer, + IAsyncEnumerable inner, + Func outerKeySelector, + Func innerKeySelector, + Func resultSelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e = inner.GetAsyncEnumerator(cancellationToken); + try + { + if (await e.MoveNextAsync().ConfigureAwait(false)) + { + AsyncLookup outerLookup = await AsyncLookup.CreateForJoinAsync(outer, outerKeySelector, comparer, cancellationToken).ConfigureAwait(false); + do + { + TInner item = e.Current; + Grouping? g = outerLookup.GetGrouping(innerKeySelector(item), create: false); + if (g is null) + { + yield return resultSelector(default, item); + } + else + { + int count = g._count; + TOuter[] elements = g._elements; + for (int i = 0; i != count; ++i) + { + yield return resultSelector(elements[i], item); + } + } + } + while (await e.MoveNextAsync().ConfigureAwait(false)); + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Correlates the elements of two sequences based on matching keys. + /// The first sequence to join. + /// The sequence to join to the first sequence. + /// A function to extract the join key from each element of the first sequence. + /// A function to extract the join key from each element of the second sequence. + /// A function to create a result element from two matching elements. + /// An to use to hash and compare keys. + /// The type of the elements of the first sequence. + /// The type of the elements of the second sequence. + /// The type of the keys returned by the key selector functions. + /// The type of the result elements. + /// An that has elements of type that are obtained by performing a right outer join on two sequences. + /// is . + /// is . + /// is . + /// is . + /// is . + public static IAsyncEnumerable RightJoin( + this IAsyncEnumerable outer, + IAsyncEnumerable inner, + Func> outerKeySelector, + Func> innerKeySelector, + Func> resultSelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(outer); + ThrowHelper.ThrowIfNull(inner); + ThrowHelper.ThrowIfNull(outerKeySelector); + ThrowHelper.ThrowIfNull(innerKeySelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable outer, + IAsyncEnumerable inner, + Func> outerKeySelector, + Func> innerKeySelector, + Func> resultSelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e = inner.GetAsyncEnumerator(cancellationToken); + try + { + if (await e.MoveNextAsync().ConfigureAwait(false)) + { + AsyncLookup outerLookup = await AsyncLookup.CreateForJoinAsync(outer, outerKeySelector, comparer, cancellationToken).ConfigureAwait(false); + do + { + TInner item = e.Current; + Grouping? g = outerLookup.GetGrouping(await innerKeySelector(item, cancellationToken).ConfigureAwait(false), create: false); + if (g is null) + { + yield return await resultSelector(default, item, cancellationToken).ConfigureAwait(false); + } + else + { + int count = g._count; + TOuter[] elements = g._elements; + for (int i = 0; i != count; ++i) + { + yield return await resultSelector(elements[i], item, cancellationToken).ConfigureAwait(false); + } + } + } + while (await e.MoveNextAsync().ConfigureAwait(false)); + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Select.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Select.cs new file mode 100644 index 00000000000000..a366e55111b158 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Select.cs @@ -0,0 +1,149 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Projects each element of a sequence into a new form. + /// The type of the elements of source. + /// The type of the value returned by selector. + /// A sequence of values to invoke a transform function on. + /// A transform function to apply to each element. + /// + /// An whose elements are the result of + /// invoking the transform function on each element of source. + /// + /// is . + /// is . + public static IAsyncEnumerable Select( // satisfies the C# query-expression pattern + this IAsyncEnumerable source, + Func selector) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(selector); + + return Impl(source, selector, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func selector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return selector(element); + } + } + } + + /// Projects each element of a sequence into a new form. + /// The type of the elements of source. + /// The type of the value returned by selector. + /// A sequence of values to invoke a transform function on. + /// A transform function to apply to each element. + /// + /// An whose elements are the result of + /// invoking the transform function on each element of source. + /// + /// is . + /// is . + public static IAsyncEnumerable Select( + this IAsyncEnumerable source, + Func> selector) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(selector); + + return Impl(source, selector, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> selector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return await selector(element, cancellationToken).ConfigureAwait(false); + } + } + } + + /// Projects each element of a sequence into a new form by incorporating the element's index. + /// The type of the elements of source. + /// The type of the value returned by selector. + /// A sequence of values to invoke a transform function on. + /// + /// A transform function to apply to each element; the second parameter of + /// the function represents the index of the source element. + /// + /// + /// An whose elements are the result of + /// invoking the transform function on each element of source. + /// + /// is . + /// is . + public static IAsyncEnumerable Select( + this IAsyncEnumerable source, + Func selector) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(selector); + + return Impl(source, selector, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func selector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + int index = -1; + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return selector(element, checked(++index)); + } + } + } + + /// Projects each element of a sequence into a new form by incorporating the element's index. + /// The type of the elements of source. + /// The type of the value returned by selector. + /// A sequence of values to invoke a transform function on. + /// + /// A transform function to apply to each element; the second parameter of + /// the function represents the index of the source element. + /// + /// + /// An whose elements are the result of + /// invoking the transform function on each element of source. + /// + /// is . + /// is . + public static IAsyncEnumerable Select( + this IAsyncEnumerable source, + Func> selector) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(selector); + + return Impl(source, selector, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> selector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + int index = -1; + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return await selector(element, checked(++index), cancellationToken).ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/SelectMany.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/SelectMany.cs new file mode 100644 index 00000000000000..3789754f2a994e --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/SelectMany.cs @@ -0,0 +1,583 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// + /// Projects each element of a sequence to an and + /// flattens the resulting sequences into one sequence. + /// + /// The type of the elements of source. + /// The type of the elements of the sequence returned by selector. + /// A sequence of values to project. + /// A transform function to apply to each element. + /// + /// An whose elements are the result of + /// invoking the one-to-many transform function on each element of the input sequence. + /// + /// is . + /// is . + public static IAsyncEnumerable SelectMany( + this IAsyncEnumerable source, + Func> selector) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(selector); + + return Impl(source, selector, default); + + async static IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> selector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + foreach (TResult subElement in selector(element)) + { + yield return subElement; + } + } + } + } + + /// + /// Projects each element of a sequence to an and + /// flattens the resulting sequences into one sequence. + /// + /// The type of the elements of source. + /// The type of the elements of the sequence returned by selector. + /// A sequence of values to project. + /// A transform function to apply to each element. + /// + /// An whose elements are the result of + /// invoking the one-to-many transform function on each element of the input sequence. + /// + /// is . + /// is . + public static IAsyncEnumerable SelectMany( + this IAsyncEnumerable source, + Func>> selector) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(selector); + + return Impl(source, selector, default); + + async static IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func>> selector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + foreach (TResult subElement in await selector(element, cancellationToken).ConfigureAwait(false)) + { + yield return subElement; + } + } + } + } + + /// + /// Projects each element of a sequence to an and + /// flattens the resulting sequences into one sequence. + /// + /// The type of the elements of source. + /// The type of the elements of the sequence returned by selector. + /// A sequence of values to project. + /// A transform function to apply to each element. + /// + /// An whose elements are the result of + /// invoking the one-to-many transform function on each element of the input sequence. + /// + /// is . + /// is . + public static IAsyncEnumerable SelectMany( + this IAsyncEnumerable source, + Func> selector) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(selector); + + return Impl(source, selector, default); + + async static IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> selector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + await foreach (TResult subElement in selector(element).WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return subElement; + } + } + } + } + + /// + /// Projects each element of a sequence to an and + /// flattens the resulting sequences into one sequence. + /// The index of each source element is used in the projected form of that element. + /// + /// The type of the elements of source. + /// The type of the elements of the sequence returned by selector. + /// A sequence of values to project. + /// A transform function to apply to each element. + /// + /// An whose elements are the result of + /// invoking the one-to-many transform function on each element of the input sequence. + /// + /// is . + /// is . + public static IAsyncEnumerable SelectMany( + this IAsyncEnumerable source, + Func> selector) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(selector); + + return Impl(source, selector, default); + + async static IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> selector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + int index = -1; + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + foreach (TResult subElement in selector(element, checked(++index))) + { + yield return subElement; + } + } + } + } + + /// + /// Projects each element of a sequence to an and + /// flattens the resulting sequences into one sequence. + /// The index of each source element is used in the projected form of that element. + /// + /// The type of the elements of source. + /// The type of the elements of the sequence returned by selector. + /// A sequence of values to project. + /// A transform function to apply to each element. + /// + /// An whose elements are the result of + /// invoking the one-to-many transform function on each element of the input sequence. + /// + /// is . + /// is . + public static IAsyncEnumerable SelectMany( + this IAsyncEnumerable source, + Func>> selector) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(selector); + + return Impl(source, selector, default); + + async static IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func>> selector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + int index = -1; + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + foreach (TResult subElement in await selector(element, checked(++index), cancellationToken).ConfigureAwait(false)) + { + yield return subElement; + } + } + } + } + + /// + /// Projects each element of a sequence to an and + /// flattens the resulting sequences into one sequence. + /// The index of each source element is used in the projected form of that element. + /// + /// The type of the elements of source. + /// The type of the elements of the sequence returned by selector. + /// A sequence of values to project. + /// A transform function to apply to each element. + /// + /// An whose elements are the result of + /// invoking the one-to-many transform function on each element of the input sequence. + /// + /// is . + /// is . + public static IAsyncEnumerable SelectMany( + this IAsyncEnumerable source, + Func> selector) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(selector); + + return Impl(source, selector, default); + + async static IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> selector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + int index = -1; + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + await foreach (TResult subElement in selector(element, checked(++index)).WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return subElement; + } + } + } + } + + /// + /// Projects each element of a sequence to an , + /// flattens the resulting sequences into one sequence, + /// and invokes a result selector function on each element therein. The index of each source element is used in + /// the intermediate projected form of that element. + /// + /// The type of the elements of source. + /// The type of the intermediate elements collected by . + /// The type of the elements of the resulting sequence. + /// A sequence of values to project. + /// A transform function to apply to each element of the input sequence. + /// A transform function to apply to each element of the intermediate sequence. + /// + /// An whose elements are the result of + /// invoking the one-to-many transform function on each element + /// of source and then mapping each of those sequence elements and their corresponding + /// source element to a result element. + /// + /// is . + /// is . + /// is . + public static IAsyncEnumerable SelectMany( + this IAsyncEnumerable source, + Func> collectionSelector, + Func resultSelector) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(collectionSelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(source, collectionSelector, resultSelector, default); + + async static IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> collectionSelector, + Func resultSelector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + foreach (TCollection subElement in collectionSelector(element)) + { + yield return resultSelector(element, subElement); + } + } + } + } + + /// + /// Projects each element of a sequence to an , + /// flattens the resulting sequences into one sequence, + /// and invokes a result selector function on each element therein. The index of each source element is used in + /// the intermediate projected form of that element. + /// + /// The type of the elements of source. + /// The type of the intermediate elements collected by . + /// The type of the elements of the resulting sequence. + /// A sequence of values to project. + /// A transform function to apply to each element of the input sequence. + /// A transform function to apply to each element of the intermediate sequence. + /// + /// An whose elements are the result of + /// invoking the one-to-many transform function on each element + /// of source and then mapping each of those sequence elements and their corresponding + /// source element to a result element. + /// + /// is . + /// is . + /// is . + public static IAsyncEnumerable SelectMany( + this IAsyncEnumerable source, + Func>> collectionSelector, + Func> resultSelector) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(collectionSelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(source, collectionSelector, resultSelector, default); + + async static IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func>> collectionSelector, + Func> resultSelector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + foreach (TCollection subElement in await collectionSelector(element, cancellationToken).ConfigureAwait(false)) + { + yield return await resultSelector(element, subElement, cancellationToken).ConfigureAwait(false); + } + } + } + } + + /// + /// Projects each element of a sequence to an , + /// flattens the resulting sequences into one sequence, + /// and invokes a result selector function on each element therein. The index of each source element is used in + /// the intermediate projected form of that element. + /// + /// The type of the elements of source. + /// The type of the intermediate elements collected by . + /// The type of the elements of the resulting sequence. + /// A sequence of values to project. + /// A transform function to apply to each element of the input sequence. + /// A transform function to apply to each element of the intermediate sequence. + /// + /// An whose elements are the result of + /// invoking the one-to-many transform function on each element + /// of source and then mapping each of those sequence elements and their corresponding + /// source element to a result element. + /// + /// is . + /// is . + /// is . + public static IAsyncEnumerable SelectMany( // satisfies the C# query-expression pattern + this IAsyncEnumerable source, + Func> collectionSelector, + Func resultSelector) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(collectionSelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(source, collectionSelector, resultSelector, default); + + async static IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> collectionSelector, + Func resultSelector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + await foreach (TCollection subElement in collectionSelector(element).WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return resultSelector(element, subElement); + } + } + } + } + + /// + /// Projects each element of a sequence to an , + /// flattens the resulting sequences into one sequence, + /// and invokes a result selector function on each element therein. The index of each source element is used in + /// the intermediate projected form of that element. + /// + /// The type of the elements of source. + /// The type of the intermediate elements collected by . + /// The type of the elements of the resulting sequence. + /// A sequence of values to project. + /// A transform function to apply to each element of the input sequence. + /// A transform function to apply to each element of the intermediate sequence. + /// + /// An whose elements are the result of + /// invoking the one-to-many transform function on each element + /// of source and then mapping each of those sequence elements and their corresponding + /// source element to a result element. + /// + /// is . + /// is . + /// is . + public static IAsyncEnumerable SelectMany( + this IAsyncEnumerable source, + Func> collectionSelector, + Func> resultSelector) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(collectionSelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(source, collectionSelector, resultSelector, default); + + async static IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> collectionSelector, + Func> resultSelector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + await foreach (TCollection subElement in collectionSelector(element).WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return await resultSelector(element, subElement, cancellationToken).ConfigureAwait(false); + } + } + } + } + + /// + /// Projects each element of a sequence to an , + /// flattens the resulting sequences into one sequence, + /// and invokes a result selector function on each element therein. + /// + /// The type of the elements of source. + /// The type of the intermediate elements collected by . + /// The type of the elements of the resulting sequence. + /// A sequence of values to project. + /// A transform function to apply to each element of the input sequence. + /// A transform function to apply to each element of the intermediate sequence. + /// + /// An whose elements are the result of + /// invoking the one-to-many transform function on each element + /// of source and then mapping each of those sequence elements and their corresponding + /// source element to a result element. + /// + /// is . + /// is . + /// is . + public static IAsyncEnumerable SelectMany( + this IAsyncEnumerable source, + Func> collectionSelector, + Func resultSelector) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(collectionSelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(source, collectionSelector, resultSelector, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> collectionSelector, + Func resultSelector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + int index = -1; + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + foreach (TCollection subElement in collectionSelector(element, checked(++index))) + { + yield return resultSelector(element, subElement); + } + } + } + } + + /// + /// Projects each element of a sequence to an , + /// flattens the resulting sequences into one sequence, + /// and invokes a result selector function on each element therein. + /// + /// The type of the elements of source. + /// The type of the intermediate elements collected by . + /// The type of the elements of the resulting sequence. + /// A sequence of values to project. + /// A transform function to apply to each element of the input sequence. + /// A transform function to apply to each element of the intermediate sequence. + /// + /// An whose elements are the result of + /// invoking the one-to-many transform function on each element + /// of source and then mapping each of those sequence elements and their corresponding + /// source element to a result element. + /// + /// is . + /// is . + /// is . + public static IAsyncEnumerable SelectMany( + this IAsyncEnumerable source, + Func>> collectionSelector, + Func> resultSelector) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(collectionSelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(source, collectionSelector, resultSelector, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func>> collectionSelector, + Func> resultSelector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + int index = -1; + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + foreach (TCollection subElement in await collectionSelector(element, checked(++index), cancellationToken).ConfigureAwait(false)) + { + yield return await resultSelector(element, subElement, cancellationToken).ConfigureAwait(false); + } + } + } + } + + /// + /// Projects each element of a sequence to an , + /// flattens the resulting sequences into one sequence, + /// and invokes a result selector function on each element therein. + /// + /// The type of the elements of source. + /// The type of the intermediate elements collected by . + /// The type of the elements of the resulting sequence. + /// A sequence of values to project. + /// A transform function to apply to each element of the input sequence. + /// A transform function to apply to each element of the intermediate sequence. + /// + /// An whose elements are the result of + /// invoking the one-to-many transform function on each element + /// of source and then mapping each of those sequence elements and their corresponding + /// source element to a result element. + /// + /// is . + /// is . + /// is . + public static IAsyncEnumerable SelectMany( + this IAsyncEnumerable source, + Func> collectionSelector, + Func> resultSelector) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(collectionSelector); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(source, collectionSelector, resultSelector, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> collectionSelector, + Func> resultSelector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + int index = -1; + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + await foreach (TCollection subElement in collectionSelector(element, checked(++index)).WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return await resultSelector(element, subElement, cancellationToken).ConfigureAwait(false); + } + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/SequenceEqualAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/SequenceEqualAsync.cs new file mode 100644 index 00000000000000..9deabee066b73f --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/SequenceEqualAsync.cs @@ -0,0 +1,67 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Determines whether two sequences are equal by comparing their elements. + /// + /// An to compare to . + /// An to compare to the first sequence. + /// An to use to compare elements. + /// The to monitor for cancellation requests. The default is . + /// + /// true if the two source sequences are of equal length and their corresponding + /// elements compare equal according to comparer; otherwise, false. + /// + public static ValueTask SequenceEqualAsync( + this IAsyncEnumerable first, + IAsyncEnumerable second, + IEqualityComparer? comparer = null, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(first); + ThrowHelper.ThrowIfNull(second); + + return Impl(first, second, comparer ?? EqualityComparer.Default, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable first, + IAsyncEnumerable second, + IEqualityComparer comparer, + CancellationToken cancellationToken) + { + IAsyncEnumerator e1 = first.GetAsyncEnumerator(cancellationToken); + try + { + IAsyncEnumerator e2 = second.GetAsyncEnumerator(cancellationToken); + try + { + while (await e1.MoveNextAsync().ConfigureAwait(false)) + { + if (!await e2.MoveNextAsync().ConfigureAwait(false) || !comparer.Equals(e1.Current, e2.Current)) + { + return false; + } + } + + return !await e2.MoveNextAsync().ConfigureAwait(false); + } + finally + { + await e2.DisposeAsync().ConfigureAwait(false); + } + } + finally + { + await e1.DisposeAsync().ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/SingleAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/SingleAsync.cs new file mode 100644 index 00000000000000..d433ee68788ea9 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/SingleAsync.cs @@ -0,0 +1,394 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// + /// Returns the only element of a sequence, and throws an exception if there is not + /// exactly one element in the sequence. + /// + /// The type of the elements of source. + /// An to return the single element of. + /// The to monitor for cancellation requests. The default is . + /// The single element of the input sequence. + /// is . + /// The sequence is empty (via the returned task). + /// The sequence contains more than one element. (via the returned task). + public static ValueTask SingleAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + ThrowHelper.ThrowNoElementsException(); + } + + TSource result = e.Current; + if (await e.MoveNextAsync().ConfigureAwait(false)) + { + ThrowHelper.ThrowMoreThanOneElementException(); + } + + return result; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// + /// Returns the only element of a sequence that satisfies a specified condition, + /// and throws an exception if more than one such element exists. + /// + /// The type of the elements of source. + /// An to return the single element of. + /// A function to test an element for a condition. + /// The to monitor for cancellation requests. The default is . + /// The single element of the input sequence that satisfies a condition. + /// is . + /// is . + /// The sequence is empty (via the returned task). + /// No element satisfies the condition in (via the returned task). + /// More than one element satisfies the condition in (via the returned task). + public static ValueTask SingleAsync( + this IAsyncEnumerable source, + Func predicate, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func predicate, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource result = e.Current; + if (predicate(result)) + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + if (predicate(e.Current)) + { + ThrowHelper.ThrowMoreThanOneMatchException(); + } + } + + return result; + } + } + + ThrowHelper.ThrowNoElementsException(); + return default!; // Unreachable + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// + /// Returns the only element of a sequence that satisfies a specified condition, + /// and throws an exception if more than one such element exists. + /// + /// The type of the elements of source. + /// An to return the single element of. + /// A function to test an element for a condition. + /// The to monitor for cancellation requests. The default is . + /// The single element of the input sequence that satisfies a condition. + /// is . + /// is . + /// The sequence is empty (via the returned task). + /// No element satisfies the condition in (via the returned task). + /// More than one element satisfies the condition in (via the returned task). + public static ValueTask SingleAsync( + this IAsyncEnumerable source, + Func> predicate, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func> predicate, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource result = e.Current; + if (await predicate(result, cancellationToken).ConfigureAwait(false)) + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + if (await predicate(e.Current, cancellationToken).ConfigureAwait(false)) + { + ThrowHelper.ThrowMoreThanOneMatchException(); + } + } + + return result; + } + } + + ThrowHelper.ThrowNoElementsException(); + return default!; // Unreachable + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// + /// Returns the only element of a sequence, or a default value if the sequence is + /// empty; this method throws an exception if there is more than one element in the sequence. + /// + /// The type of the elements of source. + /// An to return the single element of. + /// The to monitor for cancellation requests. The default is . + /// + /// The single element of the input sequence, or the default value of + /// if the sequence contains no elements. + /// + /// is . + /// The sequence contains more than one element. (via the returned task). + public static ValueTask SingleOrDefaultAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) => + SingleOrDefaultAsync(source, default(TSource), cancellationToken); + + /// Returns the only element of a sequence, or a default value if the sequence is empty; this method throws an exception if there is more than one element in the sequence. + /// The type of the elements of . + /// An to return the single element of. + /// The default value to return if the sequence is empty. + /// The to monitor for cancellation requests. The default is . + /// The single element of the input sequence, or if the sequence contains no elements. + /// is . + /// The input sequence contains more than one element. + public static ValueTask SingleOrDefaultAsync( + this IAsyncEnumerable source, + TSource defaultValue, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source, defaultValue, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + TSource defaultValue, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + return defaultValue; + } + + TSource result = e.Current; + if (await e.MoveNextAsync().ConfigureAwait(false)) + { + ThrowHelper.ThrowMoreThanOneElementException(); + } + + return result; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// + /// Returns the only element of a sequence that satisfies a specified condition or + /// a default value if no such element exists; this method throws an exception if + /// more than one element satisfies the condition. + /// + /// The type of the elements of . + /// An to return the single element of. + /// A function to test an element for a condition. + /// The to monitor for cancellation requests. The default is . + /// + /// The single element of the input sequence that satisfies the condition, or the default value of + /// if no such element is found. + /// + /// is . + /// is . + /// The input sequence contains more than one element. + public static ValueTask SingleOrDefaultAsync( + this IAsyncEnumerable source, + Func predicate, + CancellationToken cancellationToken = default) => + SingleOrDefaultAsync(source, predicate!, default, cancellationToken); + + /// + /// Returns the only element of a sequence that satisfies a specified condition or + /// a default value if no such element exists; this method throws an exception if + /// more than one element satisfies the condition. + /// + /// The type of the elements of . + /// An to return the single element of. + /// A function to test an element for a condition. + /// The to monitor for cancellation requests. The default is . + /// + /// The single element of the input sequence that satisfies the condition, or the default value of + /// if no such element is found. + /// + /// is . + /// is . + /// The input sequence contains more than one element. + public static ValueTask SingleOrDefaultAsync( + this IAsyncEnumerable source, + Func> predicate, + CancellationToken cancellationToken = default) => + SingleOrDefaultAsync(source, predicate!, default, cancellationToken); + + /// Returns the only element of a sequence that satisfies a specified condition or a default value if no such element exists; this method throws an exception if more than one element satisfies the condition. + /// The type of the elements of . + /// An to return a single element from. + /// A function to test an element for a condition. + /// The default value to return if the sequence is empty. + /// The to monitor for cancellation requests. The default is . + /// The single element of the input sequence that satisfies the condition, or if no such element is found. + /// is . + /// is . + /// More than one element satisfies the condition in . + public static ValueTask SingleOrDefaultAsync( + this IAsyncEnumerable source, + Func predicate, + TSource defaultValue, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, defaultValue, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func predicate, + TSource defaultValue, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource result = e.Current; + if (predicate(result)) + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + if (predicate(e.Current)) + { + ThrowHelper.ThrowMoreThanOneMatchException(); + } + } + + return result; + } + } + + return defaultValue; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Returns the only element of a sequence that satisfies a specified condition or a default value if no such element exists; this method throws an exception if more than one element satisfies the condition. + /// The type of the elements of . + /// An to return a single element from. + /// A function to test an element for a condition. + /// The default value to return if the sequence is empty. + /// The to monitor for cancellation requests. The default is . + /// The single element of the input sequence that satisfies the condition, or if no such element is found. + /// is . + /// is . + /// More than one element satisfies the condition in . + public static ValueTask SingleOrDefaultAsync( + this IAsyncEnumerable source, + Func> predicate, + TSource defaultValue, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, defaultValue, cancellationToken); + + static async ValueTask Impl( + IAsyncEnumerable source, + Func> predicate, + TSource defaultValue, + CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource result = e.Current; + if (await predicate(result, cancellationToken).ConfigureAwait(false)) + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + if (await predicate(e.Current, cancellationToken).ConfigureAwait(false)) + { + ThrowHelper.ThrowMoreThanOneMatchException(); + } + } + + return result; + } + } + + return defaultValue; + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Skip.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Skip.cs new file mode 100644 index 00000000000000..764f9511c4e644 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Skip.cs @@ -0,0 +1,56 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Bypasses a specified number of elements in a sequence and then returns the remaining elements. + /// The type of the elements of source. + /// An to return elements from. + /// The number of elements to skip before returning the remaining elements. + /// An that contains the elements that occur after the specified index in the input sequence. + /// is . + public static IAsyncEnumerable Skip( + this IAsyncEnumerable source, + int count) + { + ThrowHelper.ThrowIfNull(source); + + return count <= 0 ? + source : + Impl(source, count, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + int count, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + while (count > 0 && await e.MoveNextAsync().ConfigureAwait(false)) + { + count--; + } + + if (count <= 0) + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + yield return e.Current; + } + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/SkipLast.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/SkipLast.cs new file mode 100644 index 00000000000000..9856d9d7f56a8b --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/SkipLast.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// + /// Returns a new sequence that contains the elements from + /// with the last elements of the source collection omitted. + /// + /// The type of the elements of . + /// An to return elements from. + /// The number of elements to omit from the end of the sequence. + /// + /// A new sequence that contains the elements from minus + /// elements from the end of the sequence. + /// + /// is . + public static IAsyncEnumerable SkipLast( + this IAsyncEnumerable source, + int count) + { + ThrowHelper.ThrowIfNull(source); + + return + count <= 0 ? source : + TakeRangeFromEndIterator(source, isStartIndexFromEnd: false, startIndex: 0, isEndIndexFromEnd: true, endIndex: count, default); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/SkipWhile.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/SkipWhile.cs new file mode 100644 index 00000000000000..976508a944ff05 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/SkipWhile.cs @@ -0,0 +1,235 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// + /// Bypasses elements in a sequence as long as a specified condition is true and + /// then returns the remaining elements. + /// + /// The type of the elements of source. + /// An to return elements from. + /// A function to test each element for a condition. + /// + /// An that contains the elements from the + /// input sequence starting at the first element in the linear series that does not + /// pass the test specified by predicate. + /// + /// is . + /// is . + public static IAsyncEnumerable SkipWhile( + this IAsyncEnumerable source, + Func predicate) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func predicate, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource element = e.Current; + if (!predicate(element)) + { + yield return element; + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + yield return e.Current; + } + + yield break; + } + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// + /// Bypasses elements in a sequence as long as a specified condition is true and + /// then returns the remaining elements. + /// + /// The type of the elements of source. + /// An to return elements from. + /// A function to test each element for a condition. + /// + /// An that contains the elements from the + /// input sequence starting at the first element in the linear series that does not + /// pass the test specified by predicate. + /// + /// is . + /// is . + public static IAsyncEnumerable SkipWhile( + this IAsyncEnumerable source, + Func> predicate) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> predicate, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource element = e.Current; + if (!await predicate(element, cancellationToken).ConfigureAwait(false)) + { + yield return element; + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + yield return e.Current; + } + + yield break; + } + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// + /// Bypasses elements in a sequence as long as a specified condition is true and + /// then returns the remaining elements. The element's index is used in the logic + /// of the predicate function. + /// + /// The type of the elements of source. + /// An to return elements from. + /// + /// A function to test each element for a condition; the second parameter + /// of the function represents the index of the source element. + /// + /// + /// An that contains the elements from the + /// input sequence starting at the first element in the linear series that does not + /// pass the test specified by predicate. + /// + /// is . + /// is . + public static IAsyncEnumerable SkipWhile( + this IAsyncEnumerable source, + Func predicate) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func predicate, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + int index = -1; + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource element = e.Current; + if (!predicate(element, checked(++index))) + { + yield return element; + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + yield return e.Current; + } + + yield break; + } + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// + /// Bypasses elements in a sequence as long as a specified condition is true and + /// then returns the remaining elements. The element's index is used in the logic + /// of the predicate function. + /// + /// The type of the elements of source. + /// An to return elements from. + /// + /// A function to test each element for a condition; the second parameter + /// of the function represents the index of the source element. + /// + /// + /// An that contains the elements from the + /// input sequence starting at the first element in the linear series that does not + /// pass the test specified by predicate. + /// + /// is . + /// is . + public static IAsyncEnumerable SkipWhile( + this IAsyncEnumerable source, + Func> predicate) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> predicate, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + int index = -1; + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + TSource element = e.Current; + if (!await predicate(element, checked(++index), cancellationToken).ConfigureAwait(false)) + { + yield return element; + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + yield return e.Current; + } + + yield break; + } + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/SumAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/SumAsync.cs new file mode 100644 index 00000000000000..ab8643043e79f5 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/SumAsync.cs @@ -0,0 +1,282 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Computes the sum of a sequence of values. + /// A sequence of values to calculate the sum of. + /// The to monitor for cancellation requests. The default is . + /// The sum of the values in the sequence. + /// is . + /// The sum is larger than . + public static ValueTask SumAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + int sum = 0; + await foreach (int item in source) + { + checked { sum += item; } + } + return sum; + } + } + + /// Computes the sum of a sequence of values. + /// A sequence of values to calculate the sum of. + /// The to monitor for cancellation requests. The default is . + /// The sum of the values in the sequence. + /// is . + /// The sum is larger than . + public static ValueTask SumAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + long sum = 0; + await foreach (long item in source) + { + checked { sum += item; } + } + return sum; + } + } + + /// Computes the sum of a sequence of values. + /// A sequence of values to calculate the sum of. + /// The to monitor for cancellation requests. The default is . + /// The sum of the values in the sequence. + /// is . + public static ValueTask SumAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + double sum = 0; + await foreach (float item in source) + { + sum += item; + } + return (float)sum; + } + } + + /// Computes the sum of a sequence of values. + /// A sequence of values to calculate the sum of. + /// The to monitor for cancellation requests. The default is . + /// The sum of the values in the sequence. + /// is . + public static ValueTask SumAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + double sum = 0; + await foreach (double item in source) + { + sum += item; + } + return sum; + } + } + + /// Computes the sum of a sequence of values. + /// A sequence of values to calculate the sum of. + /// The to monitor for cancellation requests. The default is . + /// The sum of the values in the sequence. + /// is . + public static ValueTask SumAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + decimal sum = 0; + await foreach (decimal item in source) + { + sum += item; + } + return sum; + } + } + + /// Computes the sum of a sequence of nullable values. + /// A sequence of nullable values to calculate the sum of. + /// The to monitor for cancellation requests. The default is . + /// The sum of the values in the sequence. + /// is . + /// The sum is larger than . + public static ValueTask SumAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + int sum = 0; + await foreach (int? item in source) + { + if (item is not null) + { + checked { sum += item.GetValueOrDefault(); } + } + } + return sum; + } + } + + /// Computes the sum of a sequence of nullable values. + /// A sequence of nullable values to calculate the sum of. + /// The to monitor for cancellation requests. The default is . + /// The sum of the values in the sequence. + /// is . + /// The sum is larger than . + public static ValueTask SumAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + long sum = 0; + await foreach (long? item in source) + { + if (item is not null) + { + checked { sum += item.GetValueOrDefault(); } + } + } + return sum; + } + } + + /// Computes the sum of a sequence of nullable values. + /// A sequence of nullable values to calculate the sum of. + /// The to monitor for cancellation requests. The default is . + /// The sum of the values in the sequence. + /// is . + public static ValueTask SumAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + double sum = 0; + await foreach (float? item in source) + { + if (item is not null) + { + sum += item.GetValueOrDefault(); + } + } + return (float)sum; + } + } + + /// Computes the sum of a sequence of nullable values. + /// A sequence of nullable values to calculate the sum of. + /// The to monitor for cancellation requests. The default is . + /// The sum of the values in the sequence. + /// is . + public static ValueTask SumAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + double sum = 0; + await foreach (double? item in source) + { + if (item is not null) + { + sum += item.GetValueOrDefault(); + } + } + return sum; + } + } + + /// Computes the sum of a sequence of nullable values. + /// A sequence of nullable values to calculate the sum of. + /// The to monitor for cancellation requests. The default is . + /// The sum of the values in the sequence. + /// is . + public static ValueTask SumAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + decimal sum = 0; + await foreach (decimal? item in source) + { + if (item is not null) + { + sum += item.GetValueOrDefault(); + } + } + return sum; + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Take.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Take.cs new file mode 100644 index 00000000000000..003df03d59177f --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Take.cs @@ -0,0 +1,241 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Returns a specified number of contiguous elements from the start of a sequence. + /// The type of the elements of source. + /// The sequence to return elements from. + /// The number of elements to return. + /// + /// An that contains the specified number + /// of elements from the start of the input sequence. + /// + /// is . + public static IAsyncEnumerable Take( + this IAsyncEnumerable source, + int count) + { + ThrowHelper.ThrowIfNull(source); + + return count <= 0 ? + Empty() : + Impl(source, count, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + int count, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return element; + + if (--count == 0) + { + break; + } + } + } + } + + /// Returns a specified range of contiguous elements from a sequence. + /// The type of the elements of . + /// The sequence to return elements from. + /// The range of elements to return, which has start and end indexes either from the start or the end. + /// An that contains the specified of elements from the sequence. + /// + /// This method is implemented by using deferred execution. The immediate return value is an object that stores all the information that is required to perform the action. The query represented by this method is not executed until the object is enumerated either by calling its `GetEnumerator` method directly or by using `foreach` in Visual C# or `For Each` in Visual Basic. + /// Take enumerates and yields elements whose indices belong to the specified . + /// + /// is . + public static IAsyncEnumerable Take( + this IAsyncEnumerable source, + Range range) + { + ThrowHelper.ThrowIfNull(source); + + Index start = range.Start, end = range.End; + bool isStartIndexFromEnd = start.IsFromEnd, isEndIndexFromEnd = end.IsFromEnd; + int startIndex = start.Value, endIndex = end.Value; + Debug.Assert(startIndex >= 0); + Debug.Assert(endIndex >= 0); + + if (isStartIndexFromEnd) + { + if (startIndex == 0 || (isEndIndexFromEnd && endIndex >= startIndex)) + { + return Empty(); + } + } + else if (!isEndIndexFromEnd) + { + return startIndex >= endIndex ? + Empty() : + Impl(source, startIndex, endIndex, default); + } + + return TakeRangeFromEndIterator(source, isStartIndexFromEnd, startIndex, isEndIndexFromEnd, endIndex, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, int startIndex, int endIndex, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + Debug.Assert(source is not null); + Debug.Assert(startIndex >= 0 && startIndex < endIndex); + + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + int index = 0; + while (index < startIndex && await e.MoveNextAsync().ConfigureAwait(false)) + { + ++index; + } + + if (index < startIndex) + { + yield break; + } + + while (index < endIndex && await e.MoveNextAsync().ConfigureAwait(false)) + { + yield return e.Current; + ++index; + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + } + + private static async IAsyncEnumerable TakeRangeFromEndIterator( + IAsyncEnumerable source, + bool isStartIndexFromEnd, + int startIndex, + bool isEndIndexFromEnd, + int endIndex, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + Debug.Assert(source is not null); + Debug.Assert(isStartIndexFromEnd || isEndIndexFromEnd); + Debug.Assert(isStartIndexFromEnd + ? startIndex > 0 && (!isEndIndexFromEnd || startIndex > endIndex) + : startIndex >= 0 && (isEndIndexFromEnd || startIndex < endIndex)); + + Queue queue; + int count; + + if (isStartIndexFromEnd) + { + // TakeLast compat: enumerator should be disposed before yielding the first element. + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + yield break; + } + + queue = new Queue(); + queue.Enqueue(e.Current); + count = 1; + + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + if (count < startIndex) + { + queue.Enqueue(e.Current); + ++count; + } + else + { + do + { + queue.Dequeue(); + queue.Enqueue(e.Current); + checked { ++count; } + } + while (await e.MoveNextAsync().ConfigureAwait(false)); + + break; + } + } + + Debug.Assert(queue.Count == Math.Min(count, startIndex)); + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + + startIndex = CalculateStartIndexFromEnd(startIndex, count); + endIndex = CalculateEndIndex(isEndIndexFromEnd, endIndex, count); + Debug.Assert(endIndex - startIndex <= queue.Count); + + for (int rangeIndex = startIndex; rangeIndex < endIndex; rangeIndex++) + { + yield return queue.Dequeue(); + } + } + else + { + Debug.Assert(!isStartIndexFromEnd && isEndIndexFromEnd); + + // SkipLast compat: the enumerator should be disposed at the end of the enumeration. + IAsyncEnumerator e = source.GetAsyncEnumerator(cancellationToken); + try + { + count = 0; + while (count < startIndex && await e.MoveNextAsync().ConfigureAwait(false)) + { + ++count; + } + + if (count == startIndex) + { + queue = new Queue(); + while (await e.MoveNextAsync().ConfigureAwait(false)) + { + if (queue.Count == endIndex) + { + do + { + queue.Enqueue(e.Current); + yield return queue.Dequeue(); + } + while (await e.MoveNextAsync().ConfigureAwait(false)); + + break; + } + else + { + queue.Enqueue(e.Current); + } + } + } + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + + static int CalculateStartIndexFromEnd(int startIndex, int count) => + Math.Max(0, count - startIndex); + + static int CalculateEndIndex(bool isEndIndexFromEnd, int endIndex, int count) => + Math.Min(count, isEndIndexFromEnd ? count - endIndex : endIndex); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/TakeLast.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/TakeLast.cs new file mode 100644 index 00000000000000..b3fd0df73684cc --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/TakeLast.cs @@ -0,0 +1,26 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Returns a new sequence that contains the last elements from . + /// The type of the elements in . + /// A sequence to return elements from. + /// The number of elements to take from the end of the sequence. + /// A new sequence that contains the last elements from . + public static IAsyncEnumerable TakeLast( + this IAsyncEnumerable source, + int count) + { + ThrowHelper.ThrowIfNull(source); + + return count <= 0 ? + Empty() : + TakeRangeFromEndIterator(source, isStartIndexFromEnd: true, startIndex: count, isEndIndexFromEnd: true, endIndex: 0, default); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/TakeWhile.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/TakeWhile.cs new file mode 100644 index 00000000000000..cf9e95b5047a74 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/TakeWhile.cs @@ -0,0 +1,164 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Returns elements from a sequence as long as a specified condition is true. + /// The type of the elements of source. + /// A sequence to return elements from. + /// A function to test each element for a condition. + /// + /// An that contains the elements from the + /// input sequence that occur before the element at which the test no longer passes. + /// + /// is . + /// is . + public static IAsyncEnumerable TakeWhile( + this IAsyncEnumerable source, + Func predicate) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, Func predicate, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (!predicate(element)) + { + break; + } + + yield return element; + } + } + } + + /// Returns elements from a sequence as long as a specified condition is true. + /// The type of the elements of source. + /// A sequence to return elements from. + /// A function to test each element for a condition. + /// + /// An that contains the elements from the + /// input sequence that occur before the element at which the test no longer passes. + /// + /// is . + /// is . + public static IAsyncEnumerable TakeWhile( + this IAsyncEnumerable source, + Func> predicate) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> predicate, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (!await predicate(element, cancellationToken).ConfigureAwait(false)) + { + break; + } + + yield return element; + } + } + } + + /// + /// Returns elements from a sequence as long as a specified condition is true. + /// The element's index is used in the logic of the predicate function. + /// + /// The type of the elements of source. + /// A sequence to return elements from. + /// A function to test each element for a condition. + /// + /// An that contains the elements from the + /// input sequence that occur before the element at which the test no longer passes. + /// + /// is . + /// is . + public static IAsyncEnumerable TakeWhile( + this IAsyncEnumerable source, + Func predicate) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func predicate, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + int index = -1; + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (!predicate(element, checked(++index))) + { + break; + } + + yield return element; + } + } + } + + /// + /// Returns elements from a sequence as long as a specified condition is true. + /// The element's index is used in the logic of the predicate function. + /// + /// The type of the elements of source. + /// A sequence to return elements from. + /// A function to test each element for a condition. + /// + /// An that contains the elements from the + /// input sequence that occur before the element at which the test no longer passes. + /// + /// is . + /// is . + public static IAsyncEnumerable TakeWhile( + this IAsyncEnumerable source, + Func> predicate) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> predicate, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + int index = -1; + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (!await predicate(element, checked(++index), cancellationToken).ConfigureAwait(false)) + { + break; + } + + yield return element; + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ThrowHelper.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ThrowHelper.cs new file mode 100644 index 00000000000000..4937dcb51d4ef9 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ThrowHelper.cs @@ -0,0 +1,53 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +namespace System.Linq +{ + internal static class ThrowHelper + { + internal static void ThrowIfNull([NotNull] object? argument, [CallerArgumentExpression(nameof(argument))] string? paramName = null) + { + if (argument is null) + { + ThrowArgumentNullException(paramName); + } + + [DoesNotReturn] + static void ThrowArgumentNullException(string? paramName) => throw new ArgumentNullException(paramName); + } + + internal static void ThrowIfNegative(int value, [CallerArgumentExpression(nameof(value))] string? paramName = null) + { + if (value < 0) + { + ThrowArgumentOutOfRangeException(paramName!); + } + } + + internal static void ThrowIfNegativeOrZero(int value, [CallerArgumentExpression(nameof(value))] string? paramName = null) + { + if (value <= 0) + { + ThrowArgumentOutOfRangeException(paramName!); + } + } + + [DoesNotReturn] + internal static void ThrowArgumentOutOfRangeException(string paramName) => throw new ArgumentOutOfRangeException(paramName); + + [DoesNotReturn] + internal static void ThrowMoreThanOneElementException() => throw new InvalidOperationException(SR.MoreThanOneElement); + + [DoesNotReturn] + internal static void ThrowMoreThanOneMatchException() => throw new InvalidOperationException(SR.MoreThanOneMatch); + + [DoesNotReturn] + internal static void ThrowNoElementsException() => throw new InvalidOperationException(SR.NoElements); + + [DoesNotReturn] + internal static void ThrowNoMatchException() => throw new InvalidOperationException(SR.NoMatch); + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ToArrayAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ToArrayAsync.cs new file mode 100644 index 00000000000000..6e5dd1e16d1188 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ToArrayAsync.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Creates an array from an . + /// The type of the elements of source. + /// An to create an array from. + /// The to monitor for cancellation requests. The default is . + /// An array that contains the elements from the input sequence. + /// is . + public static ValueTask ToArrayAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask Impl( + ConfiguredCancelableAsyncEnumerable source) + { + List list = []; + await foreach (TSource element in source) + { + list.Add(element); + } + + return list.ToArray(); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ToAsyncEnumerable.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ToAsyncEnumerable.cs new file mode 100644 index 00000000000000..7f867a5aa51fa8 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ToAsyncEnumerable.cs @@ -0,0 +1,71 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Creates a new that iterates through . + /// The type of the elements of source. + /// An of the elements to enumerate. + /// An containing the sequence of elements from . + /// is . + /// + /// Each iteration through the resulting will iterate through the . + /// + public static IAsyncEnumerable ToAsyncEnumerable( + this IEnumerable source) + { + ThrowHelper.ThrowIfNull(source); + + return source switch + { + TSource[] array => FromArray(array), + List list => FromList(list), + IList list => FromIList(list), + _ => FromIterator(source), + }; + + static async IAsyncEnumerable FromArray(TSource[] source) + { + for (int i = 0; ; i++) + { + int localI = i; + TSource[] localSource = source; + if ((uint)localI >= (uint)localSource.Length) + { + break; + } + yield return localSource[localI]; + } + } + + static async IAsyncEnumerable FromList(List source) + { + for (int i = 0; i < source.Count; i++) + { + yield return source[i]; + } + } + + static async IAsyncEnumerable FromIList(IList source) + { + int count = source.Count; + for (int i = 0; i < count; i++) + { + yield return source[i]; + } + } + + static async IAsyncEnumerable FromIterator(IEnumerable source) + { + foreach (TSource element in source) + { + yield return element; + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ToDictionaryAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ToDictionaryAsync.cs new file mode 100644 index 00000000000000..afe1ce05d410d2 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ToDictionaryAsync.cs @@ -0,0 +1,236 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// + /// Creates a from an according to specified key comparer. + /// + /// The type of the keys from elements of + /// The type of the values from elements of + /// The to create a from. + /// An to compare keys. + /// The to monitor for cancellation requests. The default is . + /// A that contains keys and values from . + /// is . + /// contains one or more duplicate keys (via the returned task). + public static ValueTask> ToDictionaryAsync( + this IAsyncEnumerable> source, + IEqualityComparer? comparer = null, + CancellationToken cancellationToken = default) where TKey : notnull + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false), comparer); + + static async ValueTask> Impl( + ConfiguredCancelableAsyncEnumerable> source, + IEqualityComparer? comparer) + { + Dictionary d = new Dictionary(comparer); + await foreach (KeyValuePair element in source) + { + d.Add(element.Key, element.Value); + } + + return d; + } + } + + /// + /// Creates a from an according to specified key comparer. + /// + /// The type of the keys from elements of + /// The type of the values from elements of + /// The to create a from. + /// An to compare keys. + /// The to monitor for cancellation requests. The default is . + /// A that contains keys and values from . + /// is . + /// contains one or more duplicate keys (via the returned task). + public static ValueTask> ToDictionaryAsync( + this IAsyncEnumerable<(TKey Key, TValue Value)> source, IEqualityComparer? comparer = null, CancellationToken cancellationToken = default) where TKey : notnull => + source.ToDictionaryAsync(vt => vt.Key, vt => vt.Value, comparer, cancellationToken); + + /// + /// Creates a from an + /// according to a specified key selector function. + /// + /// The type of the elements of source. + /// The type of the keys returned by . + /// An to create a from. + /// A function to extract a key from each element. + /// An to compare keys. + /// The to monitor for cancellation requests. The default is . + /// A that contains keys and values. + /// is . + /// is . + /// contains one or more duplicate keys (via the returned task). + public static ValueTask> ToDictionaryAsync( + this IAsyncEnumerable source, + Func keySelector, + IEqualityComparer? comparer = null, + CancellationToken cancellationToken = default) where TKey : notnull + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false), keySelector, comparer); + + static async ValueTask> Impl( + ConfiguredCancelableAsyncEnumerable source, + Func keySelector, + IEqualityComparer? comparer) + { + Dictionary d = new(comparer); + await foreach (TSource element in source) + { + d.Add(keySelector(element), element); + } + return d; + } + } + + /// + /// Creates a from an + /// according to a specified key selector function. + /// + /// The type of the elements of source. + /// The type of the keys returned by . + /// An to create a from. + /// A function to extract a key from each element. + /// An to compare keys. + /// The to monitor for cancellation requests. The default is . + /// A that contains keys and values. + /// is . + /// is . + /// contains one or more duplicate keys (via the returned task). + public static ValueTask> ToDictionaryAsync( + this IAsyncEnumerable source, + Func> keySelector, + IEqualityComparer? comparer = null, + CancellationToken cancellationToken = default) where TKey : notnull + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(source, keySelector, comparer, cancellationToken); + + static async ValueTask> Impl( + IAsyncEnumerable source, + Func> keySelector, + IEqualityComparer? comparer, + CancellationToken cancellationToken) + { + Dictionary d = new(comparer); + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + d.Add(await keySelector(element, cancellationToken).ConfigureAwait(false), element); + } + return d; + } + } + + /// + /// Creates a from an "/> + /// according to specified key selector and element selector functions. + /// + /// The type of the elements of source. + /// The type of the key returned by . + /// The type of the value returned by . + /// An to create a from. + /// A function to extract a key from each element. + /// A transform function to produce a result element value from each element. + /// An to compare keys. + /// The to monitor for cancellation requests. The default is . + /// A that contains values of type selected from the input sequence. + /// is . + /// is . + /// is . + /// contains one or more duplicate keys (via the returned task). + public static ValueTask> ToDictionaryAsync( + this IAsyncEnumerable source, + Func keySelector, + Func elementSelector, + IEqualityComparer? comparer = null, + CancellationToken cancellationToken = default) where TKey : notnull + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + ThrowHelper.ThrowIfNull(elementSelector); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false), keySelector, elementSelector, comparer); + + static async ValueTask> Impl( + ConfiguredCancelableAsyncEnumerable source, + Func keySelector, + Func elementSelector, + IEqualityComparer? comparer) + { + Dictionary d = new(comparer); + await foreach (TSource element in source) + { + d.Add(keySelector(element), elementSelector(element)); + } + + return d; + } + } + + /// + /// Creates a from an "/> + /// according to specified key selector and element selector functions. + /// + /// The type of the elements of source. + /// The type of the key returned by . + /// The type of the value returned by . + /// An to create a from. + /// A function to extract a key from each element. + /// A transform function to produce a result element value from each element. + /// An to compare keys. + /// The to monitor for cancellation requests. The default is . + /// A that contains values of type selected from the input sequence. + /// is . + /// is . + /// is . + /// contains one or more duplicate keys (via the returned task). + public static ValueTask> ToDictionaryAsync( + this IAsyncEnumerable source, + Func> keySelector, + Func> elementSelector, + IEqualityComparer? comparer = null, + CancellationToken cancellationToken = default) where TKey : notnull + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + ThrowHelper.ThrowIfNull(elementSelector); + + return Impl(source, keySelector, elementSelector, comparer, cancellationToken); + + static async ValueTask> Impl( + IAsyncEnumerable source, + Func> keySelector, + Func> elementSelector, + IEqualityComparer? comparer, + CancellationToken cancellationToken) + { + Dictionary d = new(comparer); + await foreach (TSource element in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + d.Add( + await keySelector(element, cancellationToken).ConfigureAwait(false), + await elementSelector(element, cancellationToken).ConfigureAwait(false)); + } + + return d; + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ToHashSetAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ToHashSetAsync.cs new file mode 100644 index 00000000000000..0aed533e28c762 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ToHashSetAsync.cs @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Creates a from an . + /// The type of the elements of . + /// An to create a from. + /// An to compare keys. + /// The to monitor for cancellation requests. The default is . + /// A that contains values of type selected from the input sequence. + /// is . + public static ValueTask> ToHashSetAsync( + this IAsyncEnumerable source, + IEqualityComparer? comparer = null, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false), comparer); + + static async ValueTask> Impl( + ConfiguredCancelableAsyncEnumerable source, + IEqualityComparer? comparer) + { + HashSet set = new(comparer); + await foreach (TSource element in source) + { + set.Add(element); + } + + return set; + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ToListAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ToListAsync.cs new file mode 100644 index 00000000000000..c011f94b90f608 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ToListAsync.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Creates a list from an . + /// The type of the elements of source. + /// An to create a list from. + /// The to monitor for cancellation requests. The default is . + /// A list that contains the elements from the input sequence. + /// is . + public static ValueTask> ToListAsync( + this IAsyncEnumerable source, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false)); + + static async ValueTask> Impl( + ConfiguredCancelableAsyncEnumerable source) + { + List list = []; + await foreach (TSource element in source) + { + list.Add(element); + } + + return list; + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ToLookupAsync.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ToLookupAsync.cs new file mode 100644 index 00000000000000..17cd7f1475bbca --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/ToLookupAsync.cs @@ -0,0 +1,363 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// + /// Creates a from an + /// according to a specified key selector function. + /// + /// The type of the elements of . + /// The type of the key returned by . + /// The to create a from. + /// A function to extract a key from each element. + /// An to compare keys. + /// The to monitor for cancellation requests. The default is . + /// A that contains keys and values. + /// is . + /// is . + public static ValueTask> ToLookupAsync( + this IAsyncEnumerable source, + Func keySelector, + IEqualityComparer? comparer = null, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false), keySelector, comparer); + + static async ValueTask> Impl( + ConfiguredCancelableAsyncEnumerable source, + Func keySelector, + IEqualityComparer? comparer) + { + var lookup = new AsyncLookup(comparer); + await foreach (TSource item in source) + { + lookup.GetGrouping(keySelector(item), create: true)!.Add(item); + } + + return lookup; + } + } + + /// + /// Creates a from an + /// according to a specified key selector function. + /// + /// The type of the elements of . + /// The type of the key returned by . + /// The to create a from. + /// A function to extract a key from each element. + /// An to compare keys. + /// The to monitor for cancellation requests. The default is . + /// A that contains keys and values. + /// is . + /// is . + public static ValueTask> ToLookupAsync( + this IAsyncEnumerable source, + Func> keySelector, + IEqualityComparer? comparer = null, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(source, keySelector, comparer, cancellationToken); + + static async ValueTask> Impl( + IAsyncEnumerable source, + Func> keySelector, + IEqualityComparer? comparer, + CancellationToken cancellationToken) + { + var lookup = new AsyncLookup(comparer); + await foreach (TSource item in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + lookup.GetGrouping(await keySelector(item, cancellationToken).ConfigureAwait(false), create: true)!.Add(item); + } + + return lookup; + } + } + + /// + /// Creates a from an + /// according to a specified key selector function and element selector functions. + /// + /// The type of the elements of . + /// The type of the key returned by . + /// The type of the value returned by . + /// The to create a from. + /// A function to extract a key from each element. + /// A transform function to produce a result element value from each element. + /// An to compare keys. + /// The to monitor for cancellation requests. The default is . + /// A that contains keys and values. + /// is . + /// is . + public static ValueTask> ToLookupAsync( + this IAsyncEnumerable source, + Func keySelector, + Func elementSelector, + IEqualityComparer? comparer = null, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + ThrowHelper.ThrowIfNull(elementSelector); + + return Impl(source.WithCancellation(cancellationToken).ConfigureAwait(false), keySelector, elementSelector, comparer); + + static async ValueTask> Impl( + ConfiguredCancelableAsyncEnumerable source, + Func keySelector, + Func elementSelector, + IEqualityComparer? comparer) + { + var lookup = new AsyncLookup(comparer); + await foreach (TSource item in source) + { + lookup.GetGrouping(keySelector(item), create: true)!.Add(elementSelector(item)); + } + + return lookup; + } + } + + /// + /// Creates a from an + /// according to a specified key selector function and element selector functions. + /// + /// The type of the elements of . + /// The type of the key returned by . + /// The type of the value returned by . + /// The to create a from. + /// A function to extract a key from each element. + /// A transform function to produce a result element value from each element. + /// An to compare keys. + /// The to monitor for cancellation requests. The default is . + /// A that contains keys and values. + /// is . + /// is . + public static ValueTask> ToLookupAsync( + this IAsyncEnumerable source, + Func> keySelector, + Func> elementSelector, + IEqualityComparer? comparer = null, + CancellationToken cancellationToken = default) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(keySelector); + ThrowHelper.ThrowIfNull(elementSelector); + + return Impl(source, keySelector, elementSelector, comparer, cancellationToken); + + static async ValueTask> Impl( + IAsyncEnumerable source, + Func> keySelector, + Func> elementSelector, + IEqualityComparer? comparer, + CancellationToken cancellationToken) + { + var lookup = new AsyncLookup(comparer); + await foreach (TSource item in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + lookup.GetGrouping(await keySelector(item, cancellationToken).ConfigureAwait(false), create: true)! + .Add(await elementSelector(item, cancellationToken).ConfigureAwait(false)); + } + + return lookup; + } + } + + [DebuggerDisplay("Count = {Count}")] + private sealed class AsyncLookup : ILookup + { + private readonly IEqualityComparer _comparer; + private Grouping[] _groupings; + internal Grouping? _lastGrouping; + private int _count; + + internal AsyncLookup(IEqualityComparer? comparer) + { + _comparer = comparer ?? EqualityComparer.Default; + _groupings = new Grouping[7]; + } + + internal static async ValueTask> CreateForJoinAsync( + IAsyncEnumerable source, + Func keySelector, + IEqualityComparer? comparer, + CancellationToken cancellationToken) + { + Debug.Assert(source is not null); + Debug.Assert(keySelector is not null); + + AsyncLookup lookup = new(comparer); + await foreach (TElement item in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + TKey key = keySelector(item); + if (key is not null) + { + lookup.GetGrouping(key, create: true)!.Add(item); + } + } + + return lookup; + } + + internal static async ValueTask> CreateForJoinAsync( + IAsyncEnumerable source, + Func> keySelector, + IEqualityComparer? comparer, + CancellationToken cancellationToken) + { + Debug.Assert(source is not null); + Debug.Assert(keySelector is not null); + + AsyncLookup lookup = new(comparer); + await foreach (TElement item in source.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + TKey key = await keySelector(item, cancellationToken).ConfigureAwait(false); + if (key is not null) + { + lookup.GetGrouping(key, create: true)!.Add(item); + } + } + + return lookup; + } + + public int Count => _count; + + public IEnumerable this[TKey key] => GetGrouping(key, create: false) ?? Enumerable.Empty(); + + public bool Contains(TKey key) => GetGrouping(key, create: false) is not null; + + public IEnumerator> GetEnumerator() + { + Grouping? g = _lastGrouping; + if (g is not null) + { + do + { + g = g._next; + + Debug.Assert(g is not null); + yield return g; + } + while (g != _lastGrouping); + } + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + internal Grouping? GetGrouping(TKey key, bool create) + { + int hashCode = (key is null) ? 0 : _comparer.GetHashCode(key) & 0x7FFFFFFF; + for (Grouping? g = _groupings[(uint)hashCode % _groupings.Length]; g is not null; g = g._hashNext) + { + if (g._hashCode == hashCode && _comparer.Equals(g._key, key)) + { + return g; + } + } + + if (create) + { + if (_count == _groupings.Length) + { + Resize(); + } + + int index = hashCode % _groupings.Length; + Grouping g = new(key, hashCode) + { + _hashNext = _groupings[index] + }; + _groupings[index] = g; + if (_lastGrouping is null) + { + g._next = g; + } + else + { + g._next = _lastGrouping._next; + _lastGrouping._next = g; + } + + _lastGrouping = g; + _count++; + return g; + } + + return null; + } + + private void Resize() + { + int newSize = checked((_count * 2) + 1); + Grouping[] newGroupings = new Grouping[newSize]; + Grouping g = _lastGrouping!; + do + { + g = g._next!; + int index = g._hashCode % newSize; + g._hashNext = newGroupings[index]; + newGroupings[index] = g; + } + while (g != _lastGrouping); + + _groupings = newGroupings; + } + + internal IEnumerable ApplyResultSelector( + Func, TResult> resultSelector) + { + Grouping? g = _lastGrouping; + if (g is not null) + { + do + { + g = g._next; + + Debug.Assert(g is not null); + g.Trim(); + yield return resultSelector(g._key, g._elements); + } + while (g != _lastGrouping); + } + } + + internal async IAsyncEnumerable ApplyResultSelector( + Func, CancellationToken, ValueTask> resultSelector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + Grouping? g = _lastGrouping; + if (g is not null) + { + do + { + g = g._next; + + Debug.Assert(g is not null); + g.Trim(); + yield return await resultSelector(g._key, g._elements, cancellationToken).ConfigureAwait(false); + } + while (g != _lastGrouping); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Union.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Union.cs new file mode 100644 index 00000000000000..17b9e857f3599d --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Union.cs @@ -0,0 +1,57 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Produces the set union of two sequences. + /// + /// An whose distinct elements form the first set for the union. + /// An whose distinct elements form the second set for the union. + /// An to compare keys. + /// An that contains the elements from both input sequences, excluding duplicates. + /// is . + /// is . + public static IAsyncEnumerable Union( + this IAsyncEnumerable first, + IAsyncEnumerable second, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(first); + ThrowHelper.ThrowIfNull(second); + + return Impl(first, second, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable first, + IAsyncEnumerable second, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + HashSet set = new(comparer); + + await foreach (TSource element in first.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (set.Add(element)) + { + yield return element; + } + } + + await foreach (TSource element in second.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (set.Add(element)) + { + yield return element; + } + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/UnionBy.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/UnionBy.cs new file mode 100644 index 00000000000000..ccf80ed292c923 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/UnionBy.cs @@ -0,0 +1,111 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Produces the set union of two sequences according to a specified key selector function. + /// The type of the elements of the input sequences. + /// The type of key to identify elements by. + /// An whose distinct elements form the first set for the union. + /// An whose distinct elements form the second set for the union. + /// A function to extract the key for each element. + /// The to compare values. + /// An that contains the elements from both input sequences, excluding duplicates. + /// is . + /// is . + public static IAsyncEnumerable UnionBy( + this IAsyncEnumerable first, + IAsyncEnumerable second, + Func keySelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(first); + ThrowHelper.ThrowIfNull(second); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(first, second, keySelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable first, + IAsyncEnumerable second, + Func keySelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + HashSet set = new(comparer); + + await foreach (TSource element in first.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (set.Add(keySelector(element))) + { + yield return element; + } + } + + await foreach (TSource element in second.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (set.Add(keySelector(element))) + { + yield return element; + } + } + } + } + + /// Produces the set union of two sequences according to a specified key selector function. + /// The type of the elements of the input sequences. + /// The type of key to identify elements by. + /// An whose distinct elements form the first set for the union. + /// An whose distinct elements form the second set for the union. + /// A function to extract the key for each element. + /// The to compare values. + /// An that contains the elements from both input sequences, excluding duplicates. + /// is . + /// is . + public static IAsyncEnumerable UnionBy( + this IAsyncEnumerable first, + IAsyncEnumerable second, + Func> keySelector, + IEqualityComparer? comparer = null) + { + ThrowHelper.ThrowIfNull(first); + ThrowHelper.ThrowIfNull(second); + ThrowHelper.ThrowIfNull(keySelector); + + return Impl(first, second, keySelector, comparer, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable first, + IAsyncEnumerable second, + Func> keySelector, + IEqualityComparer? comparer, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + HashSet set = new(comparer); + + await foreach (TSource element in first.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (set.Add(await keySelector(element, cancellationToken).ConfigureAwait(false))) + { + yield return element; + } + } + + await foreach (TSource element in second.WithCancellation(cancellationToken).ConfigureAwait(false)) + { + if (set.Add(await keySelector(element, cancellationToken).ConfigureAwait(false))) + { + yield return element; + } + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Where.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Where.cs new file mode 100644 index 00000000000000..e915df69ac4166 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Where.cs @@ -0,0 +1,151 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// Filters a sequence of values based on a predicate. + /// The type of the elements of . + /// An to filter. + /// A function to test each element for a condition. + /// An that contains elements from the input sequence that satisfy the condition. + /// is . + /// is . + public static IAsyncEnumerable Where( // satisfies the C# query-expression pattern + this IAsyncEnumerable source, + Func predicate) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func predicate, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (TSource element in source.WithCancellation(cancellationToken)) + { + if (predicate(element)) + { + yield return element; + } + } + } + } + + /// Filters a sequence of values based on a predicate. + /// The type of the elements of . + /// An to filter. + /// A function to test each element for a condition. + /// An that contains elements from the input sequence that satisfy the condition. + /// is . + /// is . + public static IAsyncEnumerable Where( + this IAsyncEnumerable source, + Func> predicate) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> predicate, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (TSource element in source.WithCancellation(cancellationToken)) + { + if (await predicate(element, cancellationToken).ConfigureAwait(false)) + { + yield return element; + } + } + } + } + + /// + /// Filters a sequence of values based on a predicate. + /// Each element's index is used in the logic of the predicate function. + /// + /// The type of the elements of . + /// An to filter. + /// + /// A function to test each element for a condition; the second parameter + /// of the function represents the index of the source element. + /// + /// An that contains elements from the input sequence that satisfy the condition. + /// is . + /// is . + public static IAsyncEnumerable Where( + this IAsyncEnumerable source, + Func predicate) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func predicate, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + int index = -1; + await foreach (TSource element in source.WithCancellation(cancellationToken)) + { + if (predicate(element, checked(++index))) + { + yield return element; + } + } + } + } + + /// + /// Filters a sequence of values based on a predicate. + /// Each element's index is used in the logic of the predicate function. + /// + /// The type of the elements of . + /// An to filter. + /// + /// A function to test each element for a condition; the second parameter + /// of the function represents the index of the source element. + /// + /// An that contains elements from the input sequence that satisfy the condition. + /// is . + /// is . + public static IAsyncEnumerable Where( + this IAsyncEnumerable source, + Func> predicate) + { + ThrowHelper.ThrowIfNull(source); + ThrowHelper.ThrowIfNull(predicate); + + return Impl(source, predicate, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable source, + Func> predicate, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + int index = -1; + await foreach (TSource element in source.WithCancellation(cancellationToken)) + { + if (await predicate(element, checked(++index), cancellationToken).ConfigureAwait(false)) + { + yield return element; + } + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Zip.cs b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Zip.cs new file mode 100644 index 00000000000000..d4e7e6b2e6c0e4 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/src/System/Linq/Zip.cs @@ -0,0 +1,227 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Linq +{ + public static partial class AsyncEnumerable + { + /// + /// Applies a specified function to the corresponding elements of two sequences, + /// producing a sequence of the results. + /// + /// The type of the elements of the first input sequence. + /// The type of the elements of the second input sequence. + /// The type of the elements of the result sequence. + /// The first sequence to merge. + /// The second sequence to merge. + /// A function that specifies how to merge the elements from the two sequences. + /// An that contains merged elements of two input sequences. + /// is . + /// is . + /// is . + public static IAsyncEnumerable Zip( + this IAsyncEnumerable first, + IAsyncEnumerable second, + Func resultSelector) + { + ThrowHelper.ThrowIfNull(first); + ThrowHelper.ThrowIfNull(second); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(first, second, resultSelector, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable first, + IAsyncEnumerable second, + Func resultSelector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e1 = first.GetAsyncEnumerator(cancellationToken); + try + { + IAsyncEnumerator e2 = second.GetAsyncEnumerator(cancellationToken); + try + { + while (await e1.MoveNextAsync().ConfigureAwait(false) && + await e2.MoveNextAsync().ConfigureAwait(false)) + { + yield return resultSelector(e1.Current, e2.Current); + } + } + finally + { + await e2.DisposeAsync().ConfigureAwait(false); + } + } + finally + { + await e1.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// + /// Applies a specified function to the corresponding elements of two sequences, + /// producing a sequence of the results. + /// + /// The type of the elements of the first input sequence. + /// The type of the elements of the second input sequence. + /// The type of the elements of the result sequence. + /// The first sequence to merge. + /// The second sequence to merge. + /// A function that specifies how to merge the elements from the two sequences. + /// An that contains merged elements of two input sequences. + /// is . + /// is . + /// is . + public static IAsyncEnumerable Zip( + this IAsyncEnumerable first, + IAsyncEnumerable second, + Func> resultSelector) + { + ThrowHelper.ThrowIfNull(first); + ThrowHelper.ThrowIfNull(second); + ThrowHelper.ThrowIfNull(resultSelector); + + return Impl(first, second, resultSelector, default); + + static async IAsyncEnumerable Impl( + IAsyncEnumerable first, + IAsyncEnumerable second, + Func> resultSelector, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e1 = first.GetAsyncEnumerator(cancellationToken); + try + { + IAsyncEnumerator e2 = second.GetAsyncEnumerator(cancellationToken); + try + { + while (await e1.MoveNextAsync().ConfigureAwait(false) && + await e2.MoveNextAsync().ConfigureAwait(false)) + { + yield return await resultSelector(e1.Current, e2.Current, cancellationToken).ConfigureAwait(false); + } + } + finally + { + await e2.DisposeAsync().ConfigureAwait(false); + } + } + finally + { + await e1.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Produces a sequence of tuples with elements from the two specified sequences. + /// The type of the elements of the first input sequence. + /// The type of the elements of the second input sequence. + /// The first sequence to merge. + /// The second sequence to merge. + /// A sequence of tuples with elements taken from the first and second sequences, in that order. + /// is . + /// is . + public static IAsyncEnumerable<(TFirst First, TSecond Second)> Zip( + this IAsyncEnumerable first, + IAsyncEnumerable second) + { + ThrowHelper.ThrowIfNull(first); + ThrowHelper.ThrowIfNull(second); + + return Impl(first, second, default); + + static async IAsyncEnumerable<(TFirst First, TSecond Second)> Impl( + IAsyncEnumerable first, + IAsyncEnumerable second, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e1 = first.GetAsyncEnumerator(cancellationToken); + try + { + IAsyncEnumerator e2 = second.GetAsyncEnumerator(cancellationToken); + try + { + while (await e1.MoveNextAsync().ConfigureAwait(false) && + await e2.MoveNextAsync().ConfigureAwait(false)) + { + yield return (e1.Current, e2.Current); + } + } + finally + { + await e2.DisposeAsync().ConfigureAwait(false); + } + } + finally + { + await e1.DisposeAsync().ConfigureAwait(false); + } + } + } + + /// Produces a sequence of tuples with elements from the three specified sequences. + /// The type of the elements of the first input sequence. + /// The type of the elements of the second input sequence. + /// The type of the elements of the third input sequence. + /// The first sequence to merge. + /// The second sequence to merge. + /// The third sequence to merge. + /// A sequence of tuples with elements taken from the first, second, and third sequences, in that order. + /// is . + /// is . + /// is . + public static IAsyncEnumerable<(TFirst First, TSecond Second, TThird Third)> Zip( + this IAsyncEnumerable first, + IAsyncEnumerable second, + IAsyncEnumerable third) + { + ThrowHelper.ThrowIfNull(first); + ThrowHelper.ThrowIfNull(second); + ThrowHelper.ThrowIfNull(third); + + return Impl(first, second, third, default); + + static async IAsyncEnumerable<(TFirst First, TSecond Second, TThird)> Impl( + IAsyncEnumerable first, IAsyncEnumerable second, IAsyncEnumerable third, [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator e1 = first.GetAsyncEnumerator(cancellationToken); + try + { + IAsyncEnumerator e2 = second.GetAsyncEnumerator(cancellationToken); + try + { + IAsyncEnumerator e3 = third.GetAsyncEnumerator(cancellationToken); + try + { + while (await e1.MoveNextAsync().ConfigureAwait(false) && + await e2.MoveNextAsync().ConfigureAwait(false) && + await e3.MoveNextAsync().ConfigureAwait(false)) + { + yield return (e1.Current, e2.Current, e3.Current); + } + } + finally + { + await e3.DisposeAsync().ConfigureAwait(false); + } + } + finally + { + await e2.DisposeAsync().ConfigureAwait(false); + } + } + finally + { + await e1.DisposeAsync().ConfigureAwait(false); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/AggregateAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/AggregateAsyncTests.cs new file mode 100644 index 00000000000000..b7fe3f07a0662a --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/AggregateAsyncTests.cs @@ -0,0 +1,308 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class AggregateAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.AggregateAsync(null, (x, y) => x + y)); + AssertExtensions.Throws("source", () => AsyncEnumerable.AggregateAsync(null, async (x, y, ct) => x + y)); + AssertExtensions.Throws("func", () => AsyncEnumerable.AggregateAsync(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("func", () => AsyncEnumerable.AggregateAsync(AsyncEnumerable.Empty(), (Func>)null)); + + AssertExtensions.Throws("source", () => AsyncEnumerable.AggregateAsync(null, 42, (x, y) => x + y)); + AssertExtensions.Throws("source", () => AsyncEnumerable.AggregateAsync(null, 42, async (x, y, ct) => x + y)); + AssertExtensions.Throws("func", () => AsyncEnumerable.AggregateAsync(AsyncEnumerable.Empty(), 42, (Func)null)); + AssertExtensions.Throws("func", () => AsyncEnumerable.AggregateAsync(AsyncEnumerable.Empty(), 42, (Func>)null)); + + AssertExtensions.Throws("source", () => AsyncEnumerable.AggregateAsync(null, 42, (x, y) => x + y, x => x * 2)); + AssertExtensions.Throws("source", () => AsyncEnumerable.AggregateAsync(null, 42, async (x, y, ct) => x + y, async (x, ct) => x * 2)); + AssertExtensions.Throws("func", () => AsyncEnumerable.AggregateAsync(AsyncEnumerable.Empty(), 42, (Func)null, x => x * 2)); + AssertExtensions.Throws("func", () => AsyncEnumerable.AggregateAsync(AsyncEnumerable.Empty(), 42, (Func>)null, async (x, ct) => x * 2)); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.AggregateAsync(AsyncEnumerable.Empty(), 42, (x, y) => x + y, (Func)null)); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.AggregateAsync(AsyncEnumerable.Empty(), 42, async (x, y, ct) => x + y, (Func>)null)); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + public async Task VariousValues_MatchesEnumerable_Int32(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + if (values.Length > 0) + { + Assert.Equal( + values.Aggregate((x, y) => x + (2 * y)), + await source.AggregateAsync((x, y) => x + (2 * y))); + + Assert.Equal( + values.Aggregate((x, y) => x + (2 * y)), + await source.AggregateAsync(async (x, y, ct) => + { + await Task.Yield(); + return x + (2 * y); + })); + } + else + { + Assert.Throws(() => values.Aggregate((x, y) => x + (2 * y))); + await Assert.ThrowsAsync(async () => await source.AggregateAsync((x, y) => x + (2 * y))); + await Assert.ThrowsAsync(async () => await source.AggregateAsync(async (x, y, ct) => x + (2 * y))); + } + + Assert.Equal( + values.Aggregate(42, (x, y) => x + (2 * y)), + await source.AggregateAsync(42, (x, y) => x + (2 * y))); + + Assert.Equal( + values.Aggregate(42, (x, y) => x + (2 * y)), + await source.AggregateAsync(42, async (x, y, ct) => + { + await Task.Yield(); + return x + (2 * y); + })); + + Assert.Equal( + values.Aggregate(42, (x, y) => x + (2 * y), x => x * 2), + await source.AggregateAsync(42, (x, y) => x + (2 * y), x => x * 2)); + + Assert.Equal( + values.Aggregate(42, (x, y) => x + (2 * y), x => x * 2), + await source.AggregateAsync(42, async (x, y, ct) => + { + await Task.Yield(); + return x + (2 * y); + }, async (x, ct) => + { + await Task.Yield(); + return x * 2; + })); + } + } + + public static IEnumerable VariousValues_MatchesEnumerable_String_MemberData() + { + yield return new object[] { new string[0] }; + yield return new object[] { new string[] { "1" } }; + yield return new object[] { new string[] { "2", "4", "8" } }; + yield return new object[] { new string[] { "-1", "2", "5", "6", "7", "8" } }; + } + + [Theory] + [MemberData(nameof(VariousValues_MatchesEnumerable_String_MemberData))] + public async Task VariousValues_MatchesEnumerable_String(string[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + if (values.Length > 0) + { + Assert.Equal( + values.Aggregate((x, y) => x + y + y), + await source.AggregateAsync((x, y) => x + y + y)); + + Assert.Equal( + values.Aggregate((x, y) => x + y + y), + await source.AggregateAsync(async (x, y, ct) => + { + await Task.Yield(); + return x + y + y; + })); + } + else + { + Assert.Throws(() => values.Aggregate((x, y) => x + y + y)); + await Assert.ThrowsAsync(async () => await source.AggregateAsync((x, y) => x + y + y)); + } + + Assert.Equal( + values.Aggregate((string)null, (x, y) => x + y + y), + await source.AggregateAsync((string)null, (x, y) => x + y + y)); + + Assert.Equal( + values.Aggregate((string)null, (x, y) => x + y + y), + await source.AggregateAsync((string)null, async (x, y, ct) => + { + await Task.Yield(); + return x + y + y; + })); + + Assert.Equal( + values.Aggregate((string)null, (x, y) => x + y + y, x => x + x), + await source.AggregateAsync((string)null, (x, y) => x + y + y, x => x + x)); + + Assert.Equal( + values.Aggregate((string)null, (x, y) => x + y + y, x => x + x), + await source.AggregateAsync((string)null, async (x, y, ct) => + { + await Task.Yield(); + return x + y + y; + }, async (x, ct) => + { + await Task.Yield(); + return x + x; + })); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + CancellationTokenSource cts; + + cts = new(); + await Assert.ThrowsAsync(async () => await source.AggregateAsync((x, y) => + { + cts.Cancel(); + return x + y; + }, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.AggregateAsync(async (x, y, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x + y; + }, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.AggregateAsync(42, (x, y) => + { + cts.Cancel(); + return x + y; + }, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.AggregateAsync(42, async (x, y, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x + y; + }, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.AggregateAsync(42, (x, y) => + { + cts.Cancel(); + return x + y; + }, x => x, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.AggregateAsync(42, async (x, y, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x + y; + }, async (x, ct) => x, cts.Token)); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + int funcCount, resultCount; + + foreach (bool useAsync in TrueFalseBools) + { + funcCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + Assert.Equal(30, useAsync ? + await source.AggregateAsync((x, y) => { funcCount++; return x + y; }) : + await source.AggregateAsync(async (x, y, ct) => { funcCount++; return x + y; })); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(3, funcCount); + + funcCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + await Assert.ThrowsAsync(async () => + { + await (useAsync ? + source.AggregateAsync((x, y) => { funcCount++; throw new Exception(); }) : + source.AggregateAsync(async (x, y, ct) => { funcCount++; throw new Exception(); })); + }); + Assert.Equal(2, source.MoveNextAsyncCount); + Assert.Equal(2, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(1, funcCount); + + funcCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + Assert.Equal(72, useAsync ? + await source.AggregateAsync(42, (x, y) => { funcCount++; return x + y; }) : + await source.AggregateAsync(42, async (x, y, ct) => { funcCount++; return x + y; })); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, funcCount); + + funcCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + await Assert.ThrowsAsync(async () => + { + await (useAsync ? + source.AggregateAsync(42, (x, y) => { funcCount++; throw new Exception(); }) : + source.AggregateAsync(42, async (x, y, ct) => { funcCount++; throw new Exception(); })); + }); + Assert.Equal(1, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(1, funcCount); + + funcCount = resultCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + Assert.Equal(144, useAsync ? + await source.AggregateAsync(42, (x, y) => { funcCount++; return x + y; }, x => { resultCount++; return x * 2; }) : + await source.AggregateAsync(42, async (x, y, ct) => { funcCount++; return x + y; }, async (x, ct) => { resultCount++; return x * 2; })); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, funcCount); + Assert.Equal(1, resultCount); + + funcCount = resultCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + await Assert.ThrowsAsync(async () => + { + await (useAsync ? + source.AggregateAsync(42, (x, y) => { funcCount++; throw new Exception(); }, x => { resultCount++; return x * 2; }) : + source.AggregateAsync(42, async (x, y, ct) => { funcCount++; throw new Exception(); }, async (x, ct) => { resultCount++; return x * 2; })); + }); + Assert.Equal(1, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(1, funcCount); + Assert.Equal(0, resultCount); + + funcCount = resultCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + await Assert.ThrowsAsync(async () => + { + await (useAsync ? + source.AggregateAsync(42, (x, y) => { funcCount++; return x + y; }, x => { resultCount++; throw new Exception(); }) : + source.AggregateAsync(42, async (x, y, ct) => { funcCount++; return x + y; }, async (x, ct) => { resultCount++; throw new Exception(); })); + }); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, funcCount); + Assert.Equal(1, resultCount); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/AggregateByTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/AggregateByTests.cs new file mode 100644 index 00000000000000..3db27d3f02219a --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/AggregateByTests.cs @@ -0,0 +1,181 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class AggregateByTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.AggregateBy((IAsyncEnumerable)null, x => x, 42, (x, y) => x + y)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.AggregateBy(AsyncEnumerable.Empty(), (Func)null, 42, (x, y) => x + y)); + AssertExtensions.Throws("func", () => AsyncEnumerable.AggregateBy(AsyncEnumerable.Empty(), x => x, 42, (Func)null)); + + AssertExtensions.Throws("source", () => AsyncEnumerable.AggregateBy((IAsyncEnumerable)null, async (x, ct) => x, 42, async (x, y, ct) => x + y)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.AggregateBy(AsyncEnumerable.Empty(), (Func>)null, 42, async (x, y, ct) => x + y)); + AssertExtensions.Throws("func", () => AsyncEnumerable.AggregateBy(AsyncEnumerable.Empty(), async (x, ct) => x, 42, (Func>)null)); + + AssertExtensions.Throws("source", () => AsyncEnumerable.AggregateBy((IAsyncEnumerable)null, x => x, x => x, (x, y) => x + y)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.AggregateBy(AsyncEnumerable.Empty(), (Func)null, x => x, (x, y) => x + y)); + AssertExtensions.Throws("seedSelector", () => AsyncEnumerable.AggregateBy(AsyncEnumerable.Empty(), x => x, (Func)null, (x, y) => x + y)); + AssertExtensions.Throws("func", () => AsyncEnumerable.AggregateBy(AsyncEnumerable.Empty(), x => x, x => x, (Func)null)); + + AssertExtensions.Throws("source", () => AsyncEnumerable.AggregateBy((IAsyncEnumerable)null, async (x, ct) => x, async (x, ct) => x, async (x, y, ct) => x + y)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.AggregateBy(AsyncEnumerable.Empty(), (Func>)null, async (x, ct) => x, async (x, y, ct) => x + y)); + AssertExtensions.Throws("seedSelector", () => AsyncEnumerable.AggregateBy(AsyncEnumerable.Empty(), async (x, ct) => x, (Func>)null, async (x, y, ct) => x + y)); + AssertExtensions.Throws("func", () => AsyncEnumerable.AggregateBy(AsyncEnumerable.Empty(), async (x, ct) => x, async (x, ct) => x, (Func>)null)); + } + + public static IEnumerable VariousValues_MatchesEnumerable_String_MemberData() + { + yield return new object[] { new string[0] }; + yield return new object[] { new string[] { "1" } }; + yield return new object[] { new string[] { "2", "4", "8" } }; + yield return new object[] { new string[] { "12", "4", "8" } }; + yield return new object[] { new string[] { "12", "13", "14", "15", "22", "23", "24" } }; + yield return new object[] { new string[] { "-1", "2", "5", "6", "7", "8" } }; + } + +#if NET + [Theory] + [MemberData(nameof(VariousValues_MatchesEnumerable_String_MemberData))] + public async Task VariousValues_MatchesEnumerable_String(string[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + Assert.Equal( + values.AggregateBy(x => x[0], "", (x, y) => x + y).ToArray(), + await source.AggregateBy(x => x[0], "", (x, y) => x + y).ToArrayAsync()); + + Assert.Equal( + values.AggregateBy(x => x, "", (x, y) => x + y).ToArray(), + await source.AggregateBy(async (x, ct) => x, "", async (x, y, ct) => x + y).ToArrayAsync()); + + Assert.Equal( + values.AggregateBy(x => x[0], x => x.ToString() + x, (x, y) => x + y).ToArray(), + await source.AggregateBy(x => x[0], x => x.ToString() + x, (x, y) => x + y).ToArrayAsync()); + + Assert.Equal( + values.AggregateBy(x => x, x => x.ToString() + x, (x, y) => x + y).ToArray(), + await source.AggregateBy(async (x, ct) => x, async (x, ct) => x.ToString() + x, async (x, y, ct) => x + y).ToArrayAsync()); + + Assert.Equal( + values.AggregateBy(x => x[0], "", (x, y) => x + y, OddEvenComparer).ToArray(), + await source.AggregateBy(x => x[0], "", (x, y) => x + y, OddEvenComparer).ToArrayAsync()); + + Assert.Equal( + values.AggregateBy(x => x, "", (x, y) => x + y, LengthComparer).ToArray(), + await source.AggregateBy(async (x, ct) => x, "", async (x, y, ct) => x + y, LengthComparer).ToArrayAsync()); + + Assert.Equal( + values.AggregateBy(x => x[0], x => x.ToString() + x, (x, y) => x + y, OddEvenComparer).ToArray(), + await source.AggregateBy(x => x[0], x => x.ToString() + x, (x, y) => x + y, OddEvenComparer).ToArrayAsync()); + + Assert.Equal( + values.AggregateBy(x => x, x => x.ToString() + x, (x, y) => x + y, LengthComparer).ToArray(), + await source.AggregateBy(async (x, ct) => x, async (x, ct) => x.ToString() + x, async (x, y, ct) => x + y, LengthComparer).ToArrayAsync()); + } + } +#endif + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + CancellationTokenSource cts; + + cts = new(); + await Assert.ThrowsAsync(async () => + { + await ConsumeAsync(source.AggregateBy(x => + { + cts.Cancel(); + return x; + }, 42, (x, y) => x + y).WithCancellation(cts.Token)); + }); + + cts = new(); + await Assert.ThrowsAsync(async () => + { + await ConsumeAsync(source.AggregateBy(async (x, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x; + }, 42, async (x, y, ct) => x + y).WithCancellation(cts.Token)); + }); + + cts = new(); + await Assert.ThrowsAsync(async () => + { + await ConsumeAsync(source.AggregateBy(x => + { + cts.Cancel(); + return x; + }, x => x, (x, y) => x + y).WithCancellation(cts.Token)); + }); + + cts = new(); + await Assert.ThrowsAsync(async () => + { + await ConsumeAsync(source.AggregateBy(async (x, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x; + }, async (x, ct) => x, async (x, y, ct) => x + y).WithCancellation(cts.Token)); + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + int keySelectorCount, funcCount, seedSelectorCount; + + foreach (bool useAsync in TrueFalseBools) + { + keySelectorCount = funcCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + Assert.Equal([ + new(2, 44), + new(4, 46), + new(8, 50), + new(16, 58), + ], + useAsync ? + await source.AggregateBy(x => { keySelectorCount++; return x; }, 42, (x, y) => { funcCount++; return x + y; }).ToArrayAsync() : + await source.AggregateBy(async (x, ct) => { keySelectorCount++; return x; }, 42, async (x, y, ct) => { funcCount++; return x + y; }).ToArrayAsync()); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, keySelectorCount); + Assert.Equal(4, funcCount); + + keySelectorCount = funcCount = seedSelectorCount = 0; + source = CreateSource(2, 2, 2, 16).Track(); + Assert.Equal([ + new(2, 48), + new(16, 58), + ], + useAsync ? + await source.AggregateBy(x => { keySelectorCount++; return x; }, x => { seedSelectorCount++; return 42; }, (x, y) => { funcCount++; return x + y; }).ToArrayAsync() : + await source.AggregateBy(async (x, ct) => { keySelectorCount++; return x; }, async (x, ct) => { seedSelectorCount++; return 42; }, async (x, y, ct) => { funcCount++; return x + y; }).ToArrayAsync()); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, keySelectorCount); + Assert.Equal(2, seedSelectorCount); + Assert.Equal(4, funcCount); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/AllAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/AllAsyncTests.cs new file mode 100644 index 00000000000000..9eb776e4a3203b --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/AllAsyncTests.cs @@ -0,0 +1,101 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class AllAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.AllAsync(null, x => x % 2 == 0)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.AllAsync(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.AllAsync(AsyncEnumerable.Empty(), (Func>)null)); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + Func predicate = x => x % 2 == 0; + + Assert.Equal( + values.All(predicate), + await source.AllAsync(predicate)); + + Assert.Equal( + values.All(predicate), + await source.AllAsync(async (x, ct) => + { + await Task.Yield(); + return predicate(x); + })); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + + CancellationTokenSource cts = new(); + cts.Cancel(); + await Assert.ThrowsAsync(async () => await source.AllAsync(x => x % 2 == 0, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.AllAsync(async (x, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x % 2 == 0; + }, cts.Token)); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + + foreach (bool useAsync in TrueFalseBools) + { + source = CreateSource(2, 4, 8, 16).Track(); + Assert.True(useAsync ? + await source.AllAsync(x => x % 2 == 0) : + await source.AllAsync(async (x, ct) => x % 2 == 0)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + Assert.False(useAsync ? + await source.AllAsync(x => x < 4) : + await source.AllAsync(async (x, ct) => x < 4)); + Assert.Equal(2, source.MoveNextAsyncCount); + Assert.Equal(2, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + await Assert.ThrowsAsync(async () => + { + await (useAsync ? + source.AllAsync(x => throw new Exception()) : + source.AllAsync(async (x, ct) => throw new Exception())); + }); + Assert.Equal(1, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/AnyAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/AnyAsyncTests.cs new file mode 100644 index 00000000000000..68f55274b3a2a7 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/AnyAsyncTests.cs @@ -0,0 +1,113 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class AnyAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.AnyAsync(null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.AnyAsync(null, x => x % 2 == 0)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.AnyAsync(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.AnyAsync(AsyncEnumerable.Empty(), (Func>)null)); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { 1, 3, 5, 7 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + Func predicate = x => x % 2 == 0; + + Assert.Equal( + values.Any(), + await source.AnyAsync()); + + Assert.Equal( + values.Any(predicate), + await source.AnyAsync(predicate)); + + Assert.Equal( + values.Any(predicate), + await source.AnyAsync(async (x, ct) => + { + await Task.Yield(); + return predicate(x); + })); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + + CancellationTokenSource cts = new(); + cts.Cancel(); + await Assert.ThrowsAsync(async () => await source.AnyAsync(x => x < 0, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.AnyAsync(async (x, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x < 0; + }, cts.Token)); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + + source = CreateSource(2, 4, 8, 16).Track(); + Assert.True(await source.AnyAsync()); + Assert.Equal(1, source.MoveNextAsyncCount); + Assert.Equal(0, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + foreach (bool useAsync in TrueFalseBools) + { + source = CreateSource(2, 4, 8, 16).Track(); + Assert.True(useAsync ? + await source.AnyAsync(x => x > 7) : + await source.AnyAsync(async (x, ct) => x > 7)); + Assert.Equal(3, source.MoveNextAsyncCount); + Assert.Equal(3, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + Assert.False(useAsync ? + await source.AnyAsync(x => x > 20) : + await source.AnyAsync(async (x, ct) => x > 20)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + await Assert.ThrowsAsync(async () => + { + await (useAsync ? + source.AnyAsync(x => throw new Exception()) : + source.AnyAsync(async(x, ct) => throw new Exception())); + }); + Assert.Equal(1, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/AppendTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/AppendTests.cs new file mode 100644 index 00000000000000..9e2b745f16f5c3 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/AppendTests.cs @@ -0,0 +1,60 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class AppendTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.Append(null, 42)); + } + +#if NET + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + await AssertEqual( + values.Append(42), + source.Append(42)); + } + } +#endif + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + CancellationTokenSource cts = new(); + await Assert.ThrowsAsync(async () => + { + await foreach (int item in source.Append(42).WithCancellation(cts.Token)) + { + cts.Cancel(); + } + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source = CreateSource(2, 4, 8, 16).Track(); + await ConsumeAsync(source.Append(42)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/AsyncEnumerableTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/AsyncEnumerableTests.cs new file mode 100644 index 00000000000000..f1000bb272420a --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/AsyncEnumerableTests.cs @@ -0,0 +1,151 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public abstract class AsyncEnumerableTests + { + protected static IAsyncEnumerable CreateSource(params T[] items) => + items.ToAsyncEnumerable().Yield(); + + protected static IEnumerable> CreateSources(params T[] items) + { + yield return items.ToAsyncEnumerable(); + yield return items.ToAsyncEnumerable().Yield(); + } + + protected static async Task ConsumeAsync(IAsyncEnumerable source) + { + await foreach (T item in source) { } + } + + protected static async Task ConsumeAsync(ConfiguredCancelableAsyncEnumerable source) + { + await foreach (T item in source) { } + } + + protected static void FillRandom(Random rand, int[] values) + { + for (int i = 0; i < values.Length; i++) + { + values[i] = rand.Next(); + } + } + + protected static void FillRandom(Random rand, string[] values) + { + for (int i = 0; i < values.Length; i++) + { + string s = Guid.NewGuid().ToString("N"); + values[i] = s.Substring(0, rand.Next(0, s.Length)); + } + } + + protected static async Task AssertEqual(IEnumerable expected, IAsyncEnumerable actual) + { + Assert.Equal( + expected.ToArray(), + await actual.ToArrayAsync()); + } + + protected static async Task AssertEqual(IAsyncEnumerable expected, IAsyncEnumerable actual) + { + await using IAsyncEnumerator e1 = expected.GetAsyncEnumerator(); + await using IAsyncEnumerator e2 = actual.GetAsyncEnumerator(); + + while (await e1.MoveNextAsync()) + { + Assert.True(await e2.MoveNextAsync()); + Assert.Equal(e1.Current, e2.Current); + } + + Assert.False(await e2.MoveNextAsync()); + } + + protected static IEqualityComparer CreateEqualityComparer(Func equals, Func getHashCode) => + new DelegateEqualityComparer(equals, getHashCode); + + protected static IEqualityComparer OddEvenComparer { get; } = CreateEqualityComparer((x, y) => x % 2 == y % 2, x => x % 2); + + protected static IEqualityComparer LengthComparer { get; } = CreateEqualityComparer((x, y) => x.Length == y.Length, x => x.Length); + + protected static bool[] TrueFalseBools { get; } = [true, false]; + + private sealed class DelegateEqualityComparer(Func equals, Func getHashCode) : IEqualityComparer + { + public bool Equals(T x, T y) => equals(x, y); + public int GetHashCode(T obj) => getHashCode(obj); + } + } + + public static class AsyncEnumerableTestExtensions + { + public static async IAsyncEnumerable Yield( + this IAsyncEnumerable source, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (T item in source.WithCancellation(cancellationToken)) + { + cancellationToken.ThrowIfCancellationRequested(); + await Task.Yield(); + yield return item; + } + } + + public static TrackingAsyncEnumerable Track(this IAsyncEnumerable source) => + new TrackingAsyncEnumerable(source); + + public static async IAsyncEnumerable AppendException(this IAsyncEnumerable source, Exception exception, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (T item in source.WithCancellation(cancellationToken)) + { + yield return item; + } + + throw exception; + } + } + + public sealed class TrackingAsyncEnumerable(IAsyncEnumerable source) : IAsyncEnumerable + { + private readonly IAsyncEnumerable _source = source; + + public int MoveNextAsyncCount { get; set; } + + public int CurrentCount { get; set; } + + public int DisposeAsyncCount { get; set; } + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) => + new TrackDisposeAsyncEnumerator(_source.GetAsyncEnumerator(cancellationToken), this); + + private sealed class TrackDisposeAsyncEnumerator(IAsyncEnumerator source, TrackingAsyncEnumerable parent) : IAsyncEnumerator + { + public T Current + { + get + { + parent.CurrentCount++; + return source.Current; + } + } + + public ValueTask MoveNextAsync() + { + parent.MoveNextAsyncCount++; + return source.MoveNextAsync(); + } + + public ValueTask DisposeAsync() + { + parent.DisposeAsyncCount++; + return source.DisposeAsync(); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/AverageAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/AverageAsyncTests.cs new file mode 100644 index 00000000000000..72bed15ad4aa65 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/AverageAsyncTests.cs @@ -0,0 +1,125 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Globalization; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class AverageAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.AverageAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.AverageAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.AverageAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.AverageAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.AverageAsync((IAsyncEnumerable)null)); + + AssertExtensions.Throws("source", () => AsyncEnumerable.AverageAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.AverageAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.AverageAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.AverageAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.AverageAsync((IAsyncEnumerable)null)); + } + [Fact] + public async Task EmptyInputs_NonNullable_Throws() + { + await Assert.ThrowsAsync(async () => await AsyncEnumerable.AverageAsync(AsyncEnumerable.Empty())); + await Assert.ThrowsAsync(async () => await AsyncEnumerable.AverageAsync(AsyncEnumerable.Empty())); + await Assert.ThrowsAsync(async () => await AsyncEnumerable.AverageAsync(AsyncEnumerable.Empty())); + await Assert.ThrowsAsync(async () => await AsyncEnumerable.AverageAsync(AsyncEnumerable.Empty())); + await Assert.ThrowsAsync(async () => await AsyncEnumerable.AverageAsync(AsyncEnumerable.Empty())); + + Assert.Null(await AsyncEnumerable.AverageAsync(AsyncEnumerable.Empty())); + Assert.Null(await AsyncEnumerable.AverageAsync(AsyncEnumerable.Empty())); + Assert.Null(await AsyncEnumerable.AverageAsync(AsyncEnumerable.Empty())); + Assert.Null(await AsyncEnumerable.AverageAsync(AsyncEnumerable.Empty())); + Assert.Null(await AsyncEnumerable.AverageAsync(AsyncEnumerable.Empty())); + } + + [Theory] + [InlineData(new int[] { 0 })] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { -int.MaxValue, int.MaxValue })] + [InlineData(new int[] { -1, -2, -3 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + Assert.Equal(values.Select(i => (int)i).Average(), await source.Select(i => (int)i).AverageAsync()); + Assert.Equal(values.Select(i => (long)i).Average(), await source.Select(i => (long)i).AverageAsync()); + Assert.Equal(values.Select(i => (float)i).Average(), await source.Select(i => (float)i).AverageAsync()); + Assert.Equal(values.Select(i => (double)i).Average(), await source.Select(i => (double)i).AverageAsync()); + Assert.Equal(values.Select(i => (decimal)i).Average(), await source.Select(i => (decimal)i).AverageAsync()); + + Assert.Equal(values.Select(i => (int?)i).Average(), await source.Select(i => (int?)i).AverageAsync()); + Assert.Equal(values.Select(i => (long?)i).Average(), await source.Select(i => (long?)i).AverageAsync()); + Assert.Equal(values.Select(i => (float?)i).Average(), await source.Select(i => (float?)i).AverageAsync()); + Assert.Equal(values.Select(i => (double?)i).Average(), await source.Select(i => (double?)i).AverageAsync()); + Assert.Equal(values.Select(i => (decimal?)i).Average(), await source.Select(i => (decimal?)i).AverageAsync()); + + Assert.Equal(values.Select(i => (int?)i).Average(), await source.SelectMany(i => [i, null]).AverageAsync()); + Assert.Equal(values.Select(i => (long?)i).Average(), await source.SelectMany(i => [i, null]).AverageAsync()); + Assert.Equal(values.Select(i => (float?)i).Average(), await source.SelectMany(i => [i, null]).AverageAsync()); + Assert.Equal(values.Select(i => (double?)i).Average(), await source.SelectMany(i => [i, null]).AverageAsync()); + Assert.Equal(values.Select(i => (decimal?)i).Average(), await source.SelectMany(i => [i, null]).AverageAsync()); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (int)i).AverageAsync(new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (long)i).AverageAsync(new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (float)i).AverageAsync(new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (double)i).AverageAsync(new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (decimal)i).AverageAsync(new CancellationToken(true))); + + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (int?)i).AverageAsync(new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (long?)i).AverageAsync(new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (float?)i).AverageAsync(new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (double?)i).AverageAsync(new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (decimal?)i).AverageAsync(new CancellationToken(true))); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + await Validate(source => source.Select(i => (int)i).AverageAsync()); + await Validate(source => source.Select(i => (long)i).AverageAsync()); + await Validate(source => source.Select(i => (float)i).AverageAsync()); + await Validate(source => source.Select(i => (double)i).AverageAsync()); + await Validate(source => source.Select(i => (decimal)i).AverageAsync()); + + await Validate(source => source.Select(i => (int?)i).AverageAsync()); + await Validate(source => source.Select(i => (long?)i).AverageAsync()); + await Validate(source => source.Select(i => (float?)i).AverageAsync()); + await Validate(source => source.Select(i => (double?)i).AverageAsync()); + await Validate(source => source.Select(i => (decimal?)i).AverageAsync()); + + static async Task Validate(Func, ValueTask> factory) + { + TrackingAsyncEnumerable source; + + source = CreateSource(2, 4, 8, 16).Track(); + await factory(source); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).AppendException(new FormatException()).Track(); + await Assert.ThrowsAsync(async () => await factory(source)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/CastTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/CastTests.cs new file mode 100644 index 00000000000000..a7a88b8644d234 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/CastTests.cs @@ -0,0 +1,68 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class CastTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.Cast(null)); + } + + [Fact] + public async Task Empty_ProducesEmpty() + { + await AssertEqual(AsyncEnumerable.Empty(), AsyncEnumerable.Empty().Cast()); + await AssertEqual(AsyncEnumerable.Empty(), AsyncEnumerable.Empty().Cast()); + } + + [Fact] + public async Task NullAndNonNull_IncludesNulls() + { + await AssertEqual(["2", null, "8", null], CreateSource("2", null, "8", null).Cast()); + await AssertEqual(["2", null, "8", null], CreateSource("2", null, "8", null).Cast()); + await AssertEqual(["2", null, "8", null], CreateSource("2", null, "8", null).Cast()); + await AssertEqual([2, null, 8, null], CreateSource(2, null, 8, null).Cast()); + await AssertEqual([2, 8], CreateSource(2, 8).Cast()); + } + + [Fact] + public async Task IncorrectType_Throws() + { + await Assert.ThrowsAsync(async () => await ConsumeAsync(CreateSource(2, 8).Cast())); + await Assert.ThrowsAsync(async () => await ConsumeAsync(CreateSource("2", "8").Cast())); + await Assert.ThrowsAsync(async () => await ConsumeAsync(CreateSource("2", "8").Cast())); + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource("2", null, "8", null); + CancellationTokenSource cts = new(); + await Assert.ThrowsAsync(async () => + { + await foreach (string item in source.Cast().WithCancellation(cts.Token)) + { + cts.Cancel(); + } + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source = CreateSource("1", "2", "3").Track(); + await ConsumeAsync(source.Cast()); + Assert.Equal(4, source.MoveNextAsyncCount); + Assert.Equal(3, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/ChunkTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/ChunkTests.cs new file mode 100644 index 00000000000000..aa900c15e633f4 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/ChunkTests.cs @@ -0,0 +1,83 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class ChunkTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.Chunk(null, 42)); + + AsyncEnumerable.Chunk(AsyncEnumerable.Empty(), 1); + AssertExtensions.Throws("size", () => AsyncEnumerable.Chunk(AsyncEnumerable.Empty(), 0)); + AssertExtensions.Throws("size", () => AsyncEnumerable.Chunk(AsyncEnumerable.Empty(), -1)); + } + +#if NET + [Fact] + public async Task VariousValues_MatchesEnumerable() + { + Random rand = new(42); + foreach (int collectionSize in new[] { 0, 1, 10, 50 }) + { + foreach (int chunkSize in new[] { 1, 2, 3, 5, 60 }) + { + int[] ints = new int[collectionSize]; + FillRandom(rand, ints); + + foreach (IAsyncEnumerable source in CreateSources(ints)) + { + IEnumerable chunksExpected = ints.Chunk(chunkSize); + IAsyncEnumerable chunksActual = source.Chunk(chunkSize); + + IEnumerator e1 = chunksExpected.GetEnumerator(); + IAsyncEnumerator e2 = chunksActual.GetAsyncEnumerator(); + + while (e1.MoveNext()) + { + Assert.True(await e2.MoveNextAsync()); + Assert.Equal(e1.Current, e2.Current); + } + + Assert.False(await e2.MoveNextAsync()); + + e1.Dispose(); + await e2.DisposeAsync(); + } + } + } + } +#endif + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + CancellationTokenSource cts = new(); + await Assert.ThrowsAsync(async () => + { + await foreach (int[] item in source.Chunk(2).WithCancellation(cts.Token)) + { + cts.Cancel(); + } + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source = CreateSource(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).Track(); + await ConsumeAsync(source.Chunk(3)); + Assert.Equal(11, source.MoveNextAsyncCount); + Assert.Equal(10, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/ConcatTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/ConcatTests.cs new file mode 100644 index 00000000000000..b6d649300af7ae --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/ConcatTests.cs @@ -0,0 +1,79 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class ConcatTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("first", () => AsyncEnumerable.Concat(null, AsyncEnumerable.Empty())); + AssertExtensions.Throws("second", () => AsyncEnumerable.Concat(AsyncEnumerable.Empty(), null)); + } + + [Theory] + [InlineData(new int[0], new int[0])] + [InlineData(new int[0], new int[] { 42 })] + [InlineData(new int[] { 42, 43 }, new int[0])] + [InlineData(new int[] { 1 }, new int[] { 2, 3 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 3, 5 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 2, 4, 8 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 2, 5, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 }, new int[] { int.MinValue, int.MaxValue })] + public async Task VariousValues_MatchesEnumerable(int[] first, int[] second) + { + foreach (IAsyncEnumerable firstSource in CreateSources(first)) + { + foreach (IAsyncEnumerable secondSource in CreateSources(second)) + { + await AssertEqual( + first.Concat(second), + firstSource.Concat(secondSource)); + + await AssertEqual( + second.Concat(first), + secondSource.Concat(firstSource)); + } + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable first = CreateSource(2, 4, 8, 16); + IAsyncEnumerable second = CreateSource(1, 3, 5); + CancellationTokenSource cts = new(); + await Assert.ThrowsAsync(async () => + { + await foreach (int item in first.Concat(second).WithCancellation(cts.Token)) + { + cts.Cancel(); + } + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable first, second; + + first = CreateSource(2, 4, 8, 16).Track(); + second = CreateSource(1, 3, 5).Track(); + await ConsumeAsync(first.Concat(second)); + + Assert.Equal(5, first.MoveNextAsyncCount); + Assert.Equal(4, first.CurrentCount); + Assert.Equal(1, first.DisposeAsyncCount); + + Assert.Equal(4, second.MoveNextAsyncCount); + Assert.Equal(3, second.CurrentCount); + Assert.Equal(1, second.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/ContainsAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/ContainsAsyncTests.cs new file mode 100644 index 00000000000000..50500ccef72ed8 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/ContainsAsyncTests.cs @@ -0,0 +1,87 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class ContainsAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.ContainsAsync(null, 42)); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 5, 6, 7, 8, 2})] + public async Task VariousValues_MatchesEnumerable_Int32(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + Assert.Equal( + values.Contains(2), + await source.ContainsAsync(2)); + + Assert.Equal( + values.Contains(-2, OddEvenComparer), + await source.ContainsAsync(-2, OddEvenComparer)); + } + } + + public static IEnumerable VariousValues_MatchesEnumerable_String_MemberData() + { + yield return new object[] { new string[0] }; + yield return new object[] { new string[] { "1" } }; + yield return new object[] { new string[] { "2", "4", "8" } }; + yield return new object[] { new string[] { "-1", "5", "6", "7", "8", "2", "12" } }; + } + + [Theory] + [MemberData(nameof(VariousValues_MatchesEnumerable_String_MemberData))] + public async Task VariousValues_MatchesEnumerable_String(string[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + Assert.Equal( + values.Contains("2"), + await source.ContainsAsync("2")); + + Assert.Equal( + values.Contains("00", LengthComparer), + await source.ContainsAsync("00", LengthComparer)); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(1, 3, 5); + await Assert.ThrowsAsync(async () => await source.ContainsAsync(5, comparer: null, new CancellationToken(true))); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + + source = CreateSource(1, 3, 5).Track(); + Assert.False(await source.ContainsAsync(6)); + Assert.Equal(4, source.MoveNextAsyncCount); + Assert.Equal(3, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(1, 3, 5).Track(); + Assert.True(await source.ContainsAsync(1)); + Assert.Equal(1, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/CountAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/CountAsyncTests.cs new file mode 100644 index 00000000000000..7ed1fe3258e7da --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/CountAsyncTests.cs @@ -0,0 +1,165 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class CountAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.CountAsync(null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.CountAsync(null, x => x % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.CountAsync(null, async (x, ct) => x % 2 == 0)); + + AssertExtensions.Throws("source", () => AsyncEnumerable.LongCountAsync(null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.LongCountAsync(null, x => x % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.LongCountAsync(null, async (x, ct) => x % 2 == 0)); + + AssertExtensions.Throws("predicate", () => AsyncEnumerable.CountAsync(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.CountAsync(AsyncEnumerable.Empty(), (Func>)null)); + + AssertExtensions.Throws("predicate", () => AsyncEnumerable.LongCountAsync(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.LongCountAsync(AsyncEnumerable.Empty(), (Func>)null)); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { 1, 3, 5, 7 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + Func predicate = x => x % 2 == 0; + + Assert.Equal( + values.Count(), + await source.CountAsync()); + + Assert.Equal( + values.Count(predicate), + await source.CountAsync(predicate)); + + Assert.Equal( + values.All(predicate), + await source.AllAsync(async (x, ct) => + { + await Task.Yield(); + return predicate(x); + })); + + Assert.Equal( + values.LongCount(), + await source.LongCountAsync()); + + Assert.Equal( + values.LongCount(predicate), + await source.LongCountAsync(predicate)); + + Assert.Equal( + values.LongCount(predicate), + await source.LongCountAsync(async (x, ct) => + { + await Task.Yield(); + return predicate(x); + })); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + CancellationTokenSource cts; + + await Assert.ThrowsAsync(async () => await source.CountAsync(x => x < 0, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await source.LongCountAsync(x => x < 0, new CancellationToken(true))); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.CountAsync(async (x, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x < 0; + }, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.LongCountAsync(async (x, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x < 0; + }, cts.Token)); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + + foreach (bool useLong in TrueFalseBools) + { + source = CreateSource(2, 4, 8, 16).Track(); + Assert.Equal(4, useLong ? await source.LongCountAsync() : await source.CountAsync()); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(0, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + + foreach (bool useAsync in TrueFalseBools) + { + foreach (bool useLong in TrueFalseBools) + { + source = CreateSource(2, 4, 8, 16).Track(); + Assert.Equal(2, (useAsync, useLong) switch + { + (true, true) => await source.LongCountAsync(async (x, ct) => x > 7), + (true, false) => await source.CountAsync(async (x, ct) => x > 7), + (false, true) => await source.LongCountAsync(x => x > 7), + (false, false) => await source.CountAsync(x => x > 7) + }); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + Assert.Equal(0, (useAsync, useLong) switch + { + (true, true) => await source.LongCountAsync(async (x, ct) => x > 20), + (true, false) => await source.CountAsync(async (x, ct) => x > 20), + (false, true) => await source.LongCountAsync(x => x > 20), + (false, false) => await source.CountAsync(x => x > 20) + }); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + await Assert.ThrowsAsync(async () => + { + switch ((useAsync, useLong)) + { + case (true, true): await source.LongCountAsync((x, ct) => throw new Exception()); break; + case (true, false): await source.CountAsync((x, ct) => throw new Exception()); break; + case (false, true): await source.LongCountAsync(x => throw new Exception()); break; + case (false, false): await source.CountAsync(x => throw new Exception()); break; + } + }); + Assert.Equal(1, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/CountByTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/CountByTests.cs new file mode 100644 index 00000000000000..4ba4aff94fb9de --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/CountByTests.cs @@ -0,0 +1,94 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class CountByTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.CountBy((IAsyncEnumerable)null, x => x.Length)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.CountBy(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.CountBy(AsyncEnumerable.Empty(), (Func>)null)); + } + +#if NET + [Fact] + public async Task VariousValues_MatchesEnumerable_Strings() + { + Random rand = new(42); + foreach (int length in new[] { 0, 1, 2, 1000 }) + { + string[] values = new string[length]; + FillRandom(rand, values); + + foreach (IAsyncEnumerable source in CreateSources(values)) + { + await AssertEqual( + values.CountBy(x => x.Length), + source.CountBy(x => x.Length)); + + await AssertEqual( + values.CountBy(x => x.Length), + source.CountBy(async (x, ct) => x.Length)); + + await AssertEqual( + values.CountBy(x => x.Length, OddEvenComparer), + source.CountBy(x => x.Length, OddEvenComparer)); + + await AssertEqual( + values.CountBy(x => x.Length, OddEvenComparer), + source.CountBy(async (x, ct) => x.Length, OddEvenComparer)); + } + } + } +#endif + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + CancellationTokenSource cts; + + cts = new(); + await Assert.ThrowsAsync(async () => + { + await ConsumeAsync(source.CountBy(x => + { + cts.Cancel(); + return x; + }).WithCancellation(cts.Token)); + }); + + cts = new(); + await Assert.ThrowsAsync(async () => + { + await ConsumeAsync(source.CountBy(async (x, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x; + }).WithCancellation(cts.Token)); + }); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task InterfaceCalls_ExpectedCounts(bool useAsync) + { + TrackingAsyncEnumerable source = CreateSource(2, 4, 8, 16, 2, 7, 8).Track(); + await ConsumeAsync(useAsync ? source.CountBy(x => x) : source.CountBy(async (x, ct) => x)); + Assert.Equal(8, source.MoveNextAsyncCount); + Assert.Equal(7, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/DefaultIfEmptyTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/DefaultIfEmptyTests.cs new file mode 100644 index 00000000000000..0fae6f6e0f20ff --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/DefaultIfEmptyTests.cs @@ -0,0 +1,76 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class DefaultIfEmptyTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.DefaultIfEmpty(null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.DefaultIfEmpty(null, 42)); + + AssertExtensions.Throws("source", () => AsyncEnumerable.DefaultIfEmpty(null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.DefaultIfEmpty(null, "")); + + _ = AsyncEnumerable.DefaultIfEmpty(AsyncEnumerable.Empty(), null); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + await AssertEqual( + values.DefaultIfEmpty(), + source.DefaultIfEmpty()); + + await AssertEqual( + values.DefaultIfEmpty(42), + source.DefaultIfEmpty(42)); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + CancellationTokenSource cts = new(); + await Assert.ThrowsAsync(async () => + { + await foreach (int item in source.DefaultIfEmpty().WithCancellation(cts.Token)) + { + cts.Cancel(); + } + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + + source = CreateSource(2, 4, 8, 16).Track(); + await ConsumeAsync(source.DefaultIfEmpty()); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = AsyncEnumerable.Empty().Track(); + await ConsumeAsync(source.DefaultIfEmpty(42)); + Assert.Equal(1, source.MoveNextAsyncCount); + Assert.Equal(0, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/DistinctByTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/DistinctByTests.cs new file mode 100644 index 00000000000000..3a8b4fd64145b0 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/DistinctByTests.cs @@ -0,0 +1,116 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class DistinctByTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.DistinctBy((IAsyncEnumerable)null, x => x)); + AssertExtensions.Throws("source", () => AsyncEnumerable.DistinctBy((IAsyncEnumerable)null, async (x, ct) => x)); + + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.DistinctBy(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.DistinctBy(AsyncEnumerable.Empty(), (Func>)null)); + } + +#if NET + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 1, 1, 1, 2, 2, 2, 2, 2 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { 2, 4, 8, 2, 4, 8, 2 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + foreach (IEqualityComparer comparer in new[] { null, EqualityComparer.Default, OddEvenComparer }) + { + await AssertEqual( + values.DistinctBy(x => x, comparer), + source.DistinctBy(x => x, comparer)); + + await AssertEqual( + values.DistinctBy(x => x, comparer), + source.DistinctBy(async (x, ct) => x, comparer)); + + await AssertEqual( + values.DistinctBy(x => x / 3, comparer), + source.DistinctBy(x => x / 3, comparer)); + + await AssertEqual( + values.DistinctBy(x => x / 3, comparer), + source.DistinctBy(async (x, ct) => x / 3, comparer)); + } + } + } +#endif + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + + await Assert.ThrowsAsync(async () => + { + await ConsumeAsync(source.DistinctBy(x => x).WithCancellation(new CancellationToken(true))); + }); + + CancellationTokenSource cts = new(); + await Assert.ThrowsAsync(async () => + { + await ConsumeAsync(source.DistinctBy(x => + { + cts.Cancel(); + return x; + }).WithCancellation(cts.Token)); + }); + + cts = new(); + await Assert.ThrowsAsync(async () => + { + await ConsumeAsync(source.DistinctBy(async (x, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x; + }).WithCancellation(cts.Token)); + }); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task InterfaceCalls_ExpectedCounts(bool useAsync) + { + TrackingAsyncEnumerable source = CreateSource(2, 4, 8, 16).Track(); + int funcCount; + + funcCount = 0; + await ConsumeAsync(useAsync ? + source.DistinctBy(x => + { + funcCount++; + return x; + }) : + source.DistinctBy(async (x, ct) => + { + funcCount++; + return x; + })); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, funcCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/DistinctTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/DistinctTests.cs new file mode 100644 index 00000000000000..f30c58a5d97ab5 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/DistinctTests.cs @@ -0,0 +1,60 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class DistinctTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.Distinct(null)); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 1, 1, 1, 2, 2, 2, 2, 2 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { 2, 4, 8, 2, 4, 8, 2 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + await AssertEqual( + values.Distinct(), + source.Distinct()); + + await AssertEqual( + values.Distinct(OddEvenComparer), + source.Distinct(OddEvenComparer)); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + await Assert.ThrowsAsync(async () => + { + await ConsumeAsync(source.Distinct().WithCancellation(new CancellationToken(true))); + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source = CreateSource(2, 4, 8, 16).Track(); + await ConsumeAsync(source.Distinct()); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/ElementAtAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/ElementAtAsyncTests.cs new file mode 100644 index 00000000000000..2283a775f5c303 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/ElementAtAsyncTests.cs @@ -0,0 +1,109 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class ElementAtAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.ElementAtAsync((IAsyncEnumerable)null, 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.ElementAtAsync((IAsyncEnumerable)null, new Index(0))); + } + + [Fact] + public async Task OutOfRange_Throws() + { + foreach (int length in new[] { 0, 1, 2, 10 }) + { + int[] values = Enumerable.Range(42, length).ToArray(); + + await Assert.ThrowsAsync(async () => await values.ToAsyncEnumerable().ElementAtAsync(-1)); + await Assert.ThrowsAsync(async () => await values.ToAsyncEnumerable().ElementAtAsync(length)); + + await Assert.ThrowsAsync(async () => await values.ToAsyncEnumerable().ElementAtAsync(new Index(length))); + await Assert.ThrowsAsync(async () => await values.ToAsyncEnumerable().ElementAtAsync(new Index(0, fromEnd: true))); + } + } + + [Theory] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { 1, 3, 5, 7 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + for (int i = 0; i < values.Length; i++) + { + Assert.Equal( + values.ElementAt(i), + await source.ElementAtAsync(i)); + +#if NET + Assert.Equal( + values.ElementAt(new Index(i)), + await source.ElementAtAsync(new Index(i))); + + Assert.Equal( + values.ElementAt(new Index(values.Length - i, fromEnd: true)), + await source.ElementAtAsync(new Index(values.Length - i, fromEnd: true))); +#endif + } + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + + await Assert.ThrowsAsync(async () => await source.ElementAtAsync(1, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await source.ElementAtAsync(new Index(1), new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await source.ElementAtAsync(new Index(1, fromEnd: true), new CancellationToken(true))); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + + source = CreateSource(2, 4, 8, 16).Track(); + await source.ElementAtAsync(0); + Assert.Equal(1, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + await source.ElementAtAsync(3); + Assert.Equal(4, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + await source.ElementAtAsync(new Index(0)); + Assert.Equal(1, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + await source.ElementAtAsync(new Index(3)); + Assert.Equal(4, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + await source.ElementAtAsync(new Index(1, fromEnd: true)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/ElementAtOrDefaultAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/ElementAtOrDefaultAsyncTests.cs new file mode 100644 index 00000000000000..3cea9ca751cf48 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/ElementAtOrDefaultAsyncTests.cs @@ -0,0 +1,118 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class ElementAtOrDefaultAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.ElementAtOrDefaultAsync((IAsyncEnumerable)null, 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.ElementAtOrDefaultAsync((IAsyncEnumerable)null, new Index(0))); + } + + [Fact] + public async Task OutOfRange_ReturnsDefault() + { + foreach (int length in new[] { 0, 1, 2, 10 }) + { + int[] values = Enumerable.Range(42, length).ToArray(); + + Assert.Equal(0, await values.ToAsyncEnumerable().ElementAtOrDefaultAsync(-1)); + Assert.Equal(0, await values.ToAsyncEnumerable().ElementAtOrDefaultAsync(length)); + Assert.Equal(0, await values.ToAsyncEnumerable().ElementAtOrDefaultAsync(new Index(length))); + Assert.Equal(0, await values.ToAsyncEnumerable().ElementAtOrDefaultAsync(new Index(100, fromEnd: true))); + + Assert.Null(await values.Select(i => (int?)i).ToAsyncEnumerable().ElementAtOrDefaultAsync(-1)); + Assert.Null(await values.Select(i => (int?)i).ToAsyncEnumerable().ElementAtOrDefaultAsync(length)); + Assert.Null(await values.Select(i => (int?)i).ToAsyncEnumerable().ElementAtOrDefaultAsync(new Index(length))); + Assert.Null(await values.Select(i => (int?)i).ToAsyncEnumerable().ElementAtOrDefaultAsync(new Index(100, fromEnd: true))); + + Assert.Null(await values.Select(i => i.ToString()).ToAsyncEnumerable().ElementAtOrDefaultAsync(-1)); + Assert.Null(await values.Select(i => i.ToString()).ToAsyncEnumerable().ElementAtOrDefaultAsync(length)); + Assert.Null(await values.Select(i => i.ToString()).ToAsyncEnumerable().ElementAtOrDefaultAsync(new Index(length))); + Assert.Null(await values.Select(i => i.ToString()).ToAsyncEnumerable().ElementAtOrDefaultAsync(new Index(100, fromEnd: true))); + } + } + + [Theory] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { 1, 3, 5, 7 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + for (int i = 0; i < values.Length; i++) + { + Assert.Equal( + values.ElementAtOrDefault(i), + await source.ElementAtOrDefaultAsync(i)); + +#if NET + Assert.Equal( + values.ElementAtOrDefault(new Index(i)), + await source.ElementAtOrDefaultAsync(new Index(i))); + + Assert.Equal( + values.ElementAtOrDefault(new Index(values.Length - i, fromEnd: true)), + await source.ElementAtOrDefaultAsync(new Index(values.Length - i, fromEnd: true))); +#endif + } + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + + await Assert.ThrowsAsync(async () => await source.ElementAtOrDefaultAsync(1, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await source.ElementAtOrDefaultAsync(new Index(1), new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await source.ElementAtOrDefaultAsync(new Index(1, fromEnd: true), new CancellationToken(true))); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + + source = CreateSource(2, 4, 8, 16).Track(); + await source.ElementAtOrDefaultAsync(0); + Assert.Equal(1, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + await source.ElementAtOrDefaultAsync(3); + Assert.Equal(4, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + await source.ElementAtOrDefaultAsync(new Index(0)); + Assert.Equal(1, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + await source.ElementAtOrDefaultAsync(new Index(3)); + Assert.Equal(4, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + await source.ElementAtOrDefaultAsync(new Index(1, fromEnd: true)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/EmptyTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/EmptyTests.cs new file mode 100644 index 00000000000000..00c5bfd82557bf --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/EmptyTests.cs @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class EmptyTests : AsyncEnumerableTests + { + [Fact] + public void Empty_Idempotent() + { + IAsyncEnumerable ae = AsyncEnumerable.Empty(); + Assert.NotNull(ae); + Assert.Same(ae, AsyncEnumerable.Empty()); + Assert.NotSame(ae, AsyncEnumerable.Empty()); + + IAsyncEnumerator e = ae.GetAsyncEnumerator(default); + Assert.Same(e, ae.GetAsyncEnumerator(default)); + Assert.Same(e, ae.GetAsyncEnumerator(new CancellationToken(true))); + } + + [Fact] + public void Empty_ContainsZeroElements() + { + IAsyncEnumerator e = AsyncEnumerable.Empty().GetAsyncEnumerator(default); + + for (int i = 0; i < 2; i++) + { + ValueTask mn = e.MoveNextAsync(); + Assert.True(mn.IsCompleted); + Assert.False(mn.Result); + Assert.Equal(0, e.Current); // implementation detail + } + + ValueTask d = e.DisposeAsync(); + Assert.True(d.IsCompleted); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/ExceptByTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/ExceptByTests.cs new file mode 100644 index 00000000000000..b455f76f33492f --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/ExceptByTests.cs @@ -0,0 +1,129 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class ExceptByTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("first", () => AsyncEnumerable.ExceptBy((IAsyncEnumerable)null, AsyncEnumerable.Empty(), x => x.ToString())); + AssertExtensions.Throws("second", () => AsyncEnumerable.ExceptBy(AsyncEnumerable.Empty(), null, x => x.Length)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.ExceptBy(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (Func)null)); + + AssertExtensions.Throws("first", () => AsyncEnumerable.ExceptBy((IAsyncEnumerable)null, AsyncEnumerable.Empty(), async (x, ct) => x.ToString())); + AssertExtensions.Throws("second", () => AsyncEnumerable.ExceptBy(AsyncEnumerable.Empty(), null, async (x, ct) => x.Length)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.ExceptBy(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (Func>)null)); + } + +#if NET + [Theory] + [InlineData(new int[0], new int[0])] + [InlineData(new int[0], new int[] { 42 })] + [InlineData(new int[] { 42, 43 }, new int[0])] + [InlineData(new int[] { 1 }, new int[] { 2, 3 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 3, 5 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 2, 4, 8 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 2, 5, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 }, new int[] { int.MinValue, int.MaxValue })] + public async Task VariousValues_MatchesEnumerable(int[] firstInts, int[] second) + { + string[] first = firstInts.Select(x => x.ToString()).ToArray(); + + foreach (IAsyncEnumerable firstSource in CreateSources(first)) + { + foreach (IAsyncEnumerable secondSource in CreateSources(second)) + { + await AssertEqual( + first.ExceptBy(second, int.Parse), + firstSource.ExceptBy(secondSource, int.Parse)); + await AssertEqual( + first.ExceptBy(second, int.Parse, OddEvenComparer), + firstSource.ExceptBy(secondSource, int.Parse, OddEvenComparer)); + + await AssertEqual( + first.ExceptBy(second, int.Parse), + firstSource.ExceptBy(secondSource, async (x, ct) => int.Parse(x))); + + await AssertEqual( + first.ExceptBy(second, int.Parse, OddEvenComparer), + firstSource.ExceptBy(secondSource, async (x, ct) => int.Parse(x), OddEvenComparer)); + } + } + } +#endif + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable first = CreateSource(2, 4, 8, 16); + IAsyncEnumerable second = CreateSource(1, 3, 5); + CancellationTokenSource cts; + + cts = new(); + await Assert.ThrowsAsync(async () => + { + await foreach (int item in first.ExceptBy(second, x => x).WithCancellation(cts.Token)) + { + cts.Cancel(); + } + }); + + cts = new(); + await Assert.ThrowsAsync(async () => + { + await ConsumeAsync(first.ExceptBy(second, x => + { + cts.Cancel(); + return x; + }).WithCancellation(cts.Token)); + }); + + cts = new(); + await Assert.ThrowsAsync(async () => + { + await ConsumeAsync(first.ExceptBy(second, async (x, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x; + }).WithCancellation(cts.Token)); + }); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task InterfaceCalls_ExpectedCounts(bool useAsync) + { + TrackingAsyncEnumerable first = CreateSource(2, 4, 8, 16, 32, 64).Track(); + TrackingAsyncEnumerable second = CreateSource(1, 3, 5).Track(); + int funcCount = 0; + await ConsumeAsync(useAsync ? + first.ExceptBy(second, async (x, ct) => + { + funcCount++; + return x; + }) : + first.ExceptBy(second, x => + { + funcCount++; + return x; + })); + Assert.Equal(7, first.MoveNextAsyncCount); + Assert.Equal(6, first.CurrentCount); + Assert.Equal(1, first.DisposeAsyncCount); + Assert.Equal(4, second.MoveNextAsyncCount); + Assert.Equal(3, second.CurrentCount); + Assert.Equal(1, second.DisposeAsyncCount); + Assert.Equal(6, funcCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/ExceptTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/ExceptTests.cs new file mode 100644 index 00000000000000..b79ad34dda52ca --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/ExceptTests.cs @@ -0,0 +1,81 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class ExceptTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("first", () => AsyncEnumerable.Except(null, AsyncEnumerable.Empty())); + AssertExtensions.Throws("second", () => AsyncEnumerable.Except(AsyncEnumerable.Empty(), null)); + } + + [Theory] + [InlineData(new int[0], new int[0])] + [InlineData(new int[0], new int[] { 42 })] + [InlineData(new int[] { 42, 43 }, new int[0])] + [InlineData(new int[] { 1 }, new int[] { 2, 3 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 3, 5 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 2, 4, 8 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 2, 5, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 }, new int[] { int.MinValue, int.MaxValue })] + public async Task VariousValues_MatchesEnumerable(int[] first, int[] second) + { + foreach (IAsyncEnumerable firstSource in CreateSources(first)) + { + foreach (IAsyncEnumerable secondSource in CreateSources(second)) + { + await AssertEqual( + first.Except(second), + firstSource.Except(secondSource)); + + await AssertEqual( + second.Except(first), + secondSource.Except(firstSource)); + + await AssertEqual( + first.Except(second, OddEvenComparer), + firstSource.Except(secondSource, OddEvenComparer)); + } + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable first = CreateSource(2, 4, 8, 16); + IAsyncEnumerable second = CreateSource(1, 3, 5); + CancellationTokenSource cts = new(); + await Assert.ThrowsAsync(async () => + { + await foreach (int item in first.Except(second).WithCancellation(cts.Token)) + { + cts.Cancel(); + } + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable first = CreateSource(2, 4, 8, 16).Track(); + TrackingAsyncEnumerable second = CreateSource(1, 3, 5).Track(); + await ConsumeAsync(first.Except(second)); + + Assert.Equal(5, first.MoveNextAsyncCount); + Assert.Equal(4, first.CurrentCount); + Assert.Equal(1, first.DisposeAsyncCount); + + Assert.Equal(4, second.MoveNextAsyncCount); + Assert.Equal(3, second.CurrentCount); + Assert.Equal(1, second.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/FirstAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/FirstAsyncTests.cs new file mode 100644 index 00000000000000..7083542cb7e261 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/FirstAsyncTests.cs @@ -0,0 +1,119 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class FirstAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.FirstAsync(null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.FirstAsync(null, i => i % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.FirstAsync(null, async (i, ct) => i % 2 == 0)); + + AssertExtensions.Throws("predicate", () => AsyncEnumerable.FirstAsync(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.FirstAsync(AsyncEnumerable.Empty(), (Func>)null)); + } + + [Fact] + public async Task EmptyInputs_Throws() + { + ValueTask first; + + first = AsyncEnumerable.Empty().FirstAsync(); + await Assert.ThrowsAsync(async () => await first); + + first = AsyncEnumerable.Empty().FirstAsync(i => i % 2 == 0); + await Assert.ThrowsAsync(async () => await first); + + first = AsyncEnumerable.Empty().FirstAsync(async (i, ct) => i % 2 == 0); + await Assert.ThrowsAsync(async () => await first); + + first = new int[] { 1, 3, 5 }.ToAsyncEnumerable().FirstAsync(i => i % 2 == 0); + await Assert.ThrowsAsync(async () => await first); + + first = new int[] { 1, 3, 5 }.ToAsyncEnumerable().FirstAsync(async (i, ct) => i % 2 == 0); + await Assert.ThrowsAsync(async () => await first); + } + + [Theory] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { 1, 3, 5, 7 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + Assert.Equal( + values.First(), + await source.FirstAsync()); + + Func predicate = i => i == values.Last(); + + Assert.Equal( + values.First(predicate), + await source.FirstAsync(predicate)); + + Assert.Equal( + values.First(predicate), + await source.FirstAsync(async (i, ct) => predicate(i))); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + CancellationTokenSource cts; + + await Assert.ThrowsAsync(async () => await source.FirstAsync(new CancellationToken(true))); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.FirstAsync(x => + { + cts.Cancel(); + return x > 32; + }, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.FirstAsync(async (x, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x > 32; + }, cts.Token)); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + + source = CreateSource(2, 4, 8, 16).Track(); + await source.FirstAsync(); + Assert.Equal(1, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + await source.FirstAsync(i => i == 8); + Assert.Equal(3, source.MoveNextAsyncCount); + Assert.Equal(3, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + await source.FirstAsync(async (i, ct) => i == 16); + Assert.Equal(4, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/FirstOrDefaultAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/FirstOrDefaultAsyncTests.cs new file mode 100644 index 00000000000000..755e879bebbec9 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/FirstOrDefaultAsyncTests.cs @@ -0,0 +1,116 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class FirstOrDefaultAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.FirstOrDefaultAsync(null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.FirstOrDefaultAsync(null, i => i % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.FirstOrDefaultAsync(null, async (i, ct) => i % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.FirstOrDefaultAsync(null, 42)); + AssertExtensions.Throws("source", () => AsyncEnumerable.FirstOrDefaultAsync(null, i => i % 2 == 0, 42)); + AssertExtensions.Throws("source", () => AsyncEnumerable.FirstOrDefaultAsync(null, async (i, ct) => i % 2 == 0, 42)); + + AssertExtensions.Throws("predicate", () => AsyncEnumerable.FirstOrDefaultAsync(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.FirstOrDefaultAsync(AsyncEnumerable.Empty(), (Func>)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.FirstOrDefaultAsync(AsyncEnumerable.Empty(), (Func)null, 42)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.FirstOrDefaultAsync(AsyncEnumerable.Empty(), (Func>)null, 42)); + } + + [Fact] + public async Task EmptyInputs_DefaultValueReturned() + { + Assert.Equal(0, await AsyncEnumerable.Empty().FirstOrDefaultAsync()); + Assert.Equal(42, await AsyncEnumerable.Empty().FirstOrDefaultAsync(42)); + Assert.Equal(0, await AsyncEnumerable.Empty().FirstOrDefaultAsync(i => i % 2 == 0)); + Assert.Equal(42, await AsyncEnumerable.Empty().FirstOrDefaultAsync(i => i % 2 == 0, 42)); + Assert.Equal(0, await AsyncEnumerable.Empty().FirstOrDefaultAsync(async (i, ct) => i % 2 == 0)); + Assert.Equal(42, await AsyncEnumerable.Empty().FirstOrDefaultAsync(async (i, ct) => i % 2 == 0, 42)); + + IAsyncEnumerable source = new int[] { 1, 3, 5 }.ToAsyncEnumerable(); + Assert.Equal(0, await source.FirstOrDefaultAsync(i => i % 2 == 0)); + Assert.Equal(42, await source.FirstOrDefaultAsync(i => i % 2 == 0, 42)); + Assert.Equal(0, await source.FirstOrDefaultAsync(async (i, ct) => i % 2 == 0)); + Assert.Equal(42, await source.FirstOrDefaultAsync(async (i, ct) => i % 2 == 0, 42)); + } + + [Theory] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { 1, 3, 5, 7 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + Assert.Equal( + values.FirstOrDefault(), + await source.FirstOrDefaultAsync()); + + Func predicate = i => i == values.Last(); + + Assert.Equal( + values.FirstOrDefault(predicate), + await source.FirstOrDefaultAsync(predicate)); + + Assert.Equal( + values.FirstOrDefault(predicate), + await source.FirstOrDefaultAsync(async (i, ct) => predicate(i))); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + CancellationTokenSource cts; + + await Assert.ThrowsAsync(async () => await source.FirstOrDefaultAsync(new CancellationToken(true))); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.FirstOrDefaultAsync(x => + { + cts.Cancel(); + return x > 32; + }, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.FirstOrDefaultAsync(async (x, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x > 32; + }, cts.Token)); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + await Validate(s => s.FirstOrDefaultAsync()); + await Validate(s => s.FirstOrDefaultAsync(42)); + await Validate(s => s.FirstOrDefaultAsync(i => i % 2 == 0)); + await Validate(s => s.FirstOrDefaultAsync(i => i % 2 == 0, 42)); + await Validate(s => s.FirstOrDefaultAsync(async (i, ct) => i % 2 == 0)); + await Validate(s => s.FirstOrDefaultAsync(async (i, ct) => i % 2 == 0, 42)); + + static async Task Validate(Func, ValueTask> func) + { + TrackingAsyncEnumerable source = CreateSource(2, 4, 8, 16).Track(); + await func(source); + Assert.Equal(1, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/GroupByTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/GroupByTests.cs new file mode 100644 index 00000000000000..a908913f80a018 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/GroupByTests.cs @@ -0,0 +1,320 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class GroupByTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.GroupBy((IAsyncEnumerable)null, s => s)); + AssertExtensions.Throws("source", () => AsyncEnumerable.GroupBy((IAsyncEnumerable)null, async (s, ct) => s)); + AssertExtensions.Throws("source", () => AsyncEnumerable.GroupBy((IAsyncEnumerable)null, s => s, s => s)); + AssertExtensions.Throws("source", () => AsyncEnumerable.GroupBy((IAsyncEnumerable)null, async (s, ct) => s, async (s, ct) => s)); + AssertExtensions.Throws("source", () => AsyncEnumerable.GroupBy((IAsyncEnumerable)null, s => s, (s, group) => s)); + AssertExtensions.Throws("source", () => AsyncEnumerable.GroupBy((IAsyncEnumerable)null, async (s, ct) => s, async (s, group, ct) => s)); + AssertExtensions.Throws("source", () => AsyncEnumerable.GroupBy((IAsyncEnumerable)null, s => s, s => s, (s, group) => s)); + AssertExtensions.Throws("source", () => AsyncEnumerable.GroupBy((IAsyncEnumerable)null, async (s, ct) => s, async (s, ct) => s, async (s, group, ct) => s)); + + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.GroupBy(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.GroupBy(AsyncEnumerable.Empty(), (Func>)null)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.GroupBy(AsyncEnumerable.Empty(), (Func)null, s => s)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.GroupBy(AsyncEnumerable.Empty(), (Func>)null, async (s, ct) => s)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.GroupBy(AsyncEnumerable.Empty(), (Func)null, (s, group) => s)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.GroupBy(AsyncEnumerable.Empty(), (Func>)null, async (s, group, ct) => s)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.GroupBy(AsyncEnumerable.Empty(), (Func)null, s => s, (s, group) => s)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.GroupBy(AsyncEnumerable.Empty(), (Func>)null, async (s, ct) => s, async (s, group, ct) => s)); + + AssertExtensions.Throws("elementSelector", () => AsyncEnumerable.GroupBy(AsyncEnumerable.Empty(), s => s, (Func)null)); + AssertExtensions.Throws("elementSelector", () => AsyncEnumerable.GroupBy(AsyncEnumerable.Empty(), async (s, ct) => s, (Func>)null)); + AssertExtensions.Throws("elementSelector", () => AsyncEnumerable.GroupBy(AsyncEnumerable.Empty(), s => s, (Func)null, (s, group) => s)); + AssertExtensions.Throws("elementSelector", () => AsyncEnumerable.GroupBy(AsyncEnumerable.Empty(), async (s, ct) => s, (Func>)null, async (s, group, ct) => s)); + + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.GroupBy(AsyncEnumerable.Empty(), s => s, (Func, string>)null)); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.GroupBy(AsyncEnumerable.Empty(), async (s, ct) => s, (Func, CancellationToken, ValueTask>)null)); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.GroupBy(AsyncEnumerable.Empty(), s => s, s => s, (Func, string>)null)); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.GroupBy(AsyncEnumerable.Empty(), async (s, ct) => s, async (s, ct) => s, (Func, CancellationToken, ValueTask>)null)); + } + + [Fact] + public async Task VariousValues_MatchesEnumerable_String() + { + Random rand = new(42); + foreach (int length in new[] { 0, 1, 2, 1000 }) + { + string[] values = new string[length]; + FillRandom(rand, values); + + foreach (IEqualityComparer comparer in new[] { null, EqualityComparer.Default, OddEvenComparer }) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + await AssertEqual( + values.GroupBy(s => s.Length, comparer), + source.GroupBy(s => s.Length, comparer)); + + await AssertEqual( + values.GroupBy(s => s.Length, comparer), + source.GroupBy(async (s, ct) => s.Length, comparer)); + + await AssertEqual( + values.GroupBy(s => s.Length, s => s.Length > 0 ? s[0] : ' ', comparer), + source.GroupBy(s => s.Length, s => s.Length > 0 ? s[0] : ' ', comparer)); + + await AssertEqual( + values.GroupBy(s => s.Length, s => s.Length > 0 ? s[0] : ' ', comparer), + source.GroupBy(async (s, ct) => s.Length, async (s, ct) => s.Length > 0 ? s[0] : ' ', comparer)); + + await AssertEqual( + values.GroupBy(s => s.Length, (key, group) => key.ToString() + string.Concat(group), comparer), + source.GroupBy(s => s.Length, (key, group) => key.ToString() + string.Concat(group), comparer)); + + await AssertEqual( + values.GroupBy(s => s.Length, (key, group) => key.ToString() + string.Concat(group), comparer), + source.GroupBy(async (s, ct) => s.Length, async (key, group, ct) => key.ToString() + string.Concat(group), comparer)); + + await AssertEqual( + values.GroupBy(s => s.Length, s => s.Length > 0 ? s.Substring(1) : "", (key, group) => key.ToString() + string.Concat(group), comparer), + source.GroupBy(s => s.Length, s => s.Length > 0 ? s.Substring(1) : "", (key, group) => key.ToString() + string.Concat(group), comparer)); + + await AssertEqual( + values.GroupBy(s => s.Length, s => s.Length > 0 ? s.Substring(1) : "", (key, group) => key.ToString() + string.Concat(group), comparer), + source.GroupBy(async (s, ct) => s.Length, async (s, ct) => s.Length > 0 ? s.Substring(1) : "", async (key, group, ct) => key.ToString() + string.Concat(group), comparer)); + } + } + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.GroupBy(s => + { + cts.Cancel(); + return s; + }).WithCancellation(cts.Token)); + }); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.GroupBy(async (s, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return s; + }).WithCancellation(cts.Token)); + }); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.GroupBy(s => s, s => + { + cts.Cancel(); + return s; + }).WithCancellation(cts.Token)); + }); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.GroupBy(async (s, ct) => s, async (s, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return s; + }).WithCancellation(cts.Token)); + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + int keySelectorCount, elementSelectorCount, resultSelectorCount; + + foreach (bool useAsync in TrueFalseBools) + { + keySelectorCount = 0; + source = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(useAsync ? + source.GroupBy(async (i, ct) => + { + keySelectorCount++; + return i % 2; + }) : + source.GroupBy(i => + { + keySelectorCount++; + return i; + })); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, keySelectorCount); + + keySelectorCount = elementSelectorCount = 0; + source = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(useAsync ? + source.GroupBy(async (i, ct) => + { + keySelectorCount++; + return i % 2; + }, async (i, ct) => + { + elementSelectorCount++; + return i; + }) : + source.GroupBy(i => + { + keySelectorCount++; + return i; + }, i => + { + elementSelectorCount++; + return i; + })); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, keySelectorCount); + Assert.Equal(4, elementSelectorCount); + + keySelectorCount = resultSelectorCount = 0; + source = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(useAsync ? + source.GroupBy(async (i, ct) => + { + keySelectorCount++; + return i % 2; + }, async (key, group, ct) => + { + resultSelectorCount++; + return key; + }) : + source.GroupBy(i => + { + keySelectorCount++; + return i % 2; + }, (key, group) => + { + resultSelectorCount++; + return key; + })); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, keySelectorCount); + Assert.Equal(2, resultSelectorCount); + + keySelectorCount = elementSelectorCount = resultSelectorCount = 0; + source = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(useAsync ? + source.GroupBy(async (i, ct) => + { + keySelectorCount++; + return i % 2; + }, async (i, ct) => + { + elementSelectorCount++; + return i; + }, async (key, group, ct) => + { + resultSelectorCount++; + return key; + }) : + source.GroupBy(i => + { + keySelectorCount++; + return i % 2; + }, i => + { + elementSelectorCount++; + return i; + }, (key, group) => + { + resultSelectorCount++; + return key; + })); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, keySelectorCount); + Assert.Equal(4, elementSelectorCount); + Assert.Equal(2, resultSelectorCount); + } + } + + [Fact] + public async Task IGrouping_ImplementsIList() + { + List> result = await AsyncEnumerable.Range(0, 100).GroupBy(i => i % 2).ToListAsync(); + foreach (IGrouping group in result) + { + IList list = Assert.IsAssignableFrom>(group); + + Assert.Equal(50, list.Count); + Assert.True(list.IsReadOnly); + + if (group.Key == 0) + { + Assert.Equal(0, list[0]); + + Assert.True(list.Contains(0)); + Assert.True(list.Contains(98)); + Assert.False(list.Contains(1)); + Assert.False(list.Contains(99)); + + Assert.Equal(0, list.IndexOf(0)); + Assert.Equal(49, list.IndexOf(98)); + Assert.Equal(-1, list.IndexOf(99)); + } + else + { + Assert.Equal(1, list[0]); + Assert.True(list.Contains(1)); + Assert.True(list.Contains(99)); + Assert.False(list.Contains(2)); + Assert.False(list.Contains(98)); + + Assert.Equal(0, list.IndexOf(1)); + Assert.Equal(49, list.IndexOf(99)); + Assert.Equal(-1, list.IndexOf(98)); + } + for (int i = 0; i < list.Count - 1; i++) + { + Assert.Equal(list[i], list[i + 1] - 2); + } + AssertExtensions.Throws("index", () => list[50]); + + int[] ints = new int[52]; + list.CopyTo(ints, 1); + Assert.Equal(0, ints[0]); + Assert.Equal(list[0], ints[1]); + Assert.Equal(list[49], ints[50]); + Assert.Equal(0, ints[51]); + + Assert.Throws(() => list.Add(0)); + Assert.Throws(() => list.Clear()); + Assert.Throws(() => list.Insert(0, 0)); + Assert.Throws(() => list.Remove(0)); + Assert.Throws(() => list.RemoveAt(0)); + Assert.Throws(() => list[0] = 0); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/GroupJoinTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/GroupJoinTests.cs new file mode 100644 index 00000000000000..c6cb865ca61e07 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/GroupJoinTests.cs @@ -0,0 +1,164 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class GroupJoinTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("outer", () => AsyncEnumerable.GroupJoin((IAsyncEnumerable)null, AsyncEnumerable.Empty(), outer => outer, inner => inner, (outer, inner) => outer + inner)); + AssertExtensions.Throws("inner", () => AsyncEnumerable.GroupJoin(AsyncEnumerable.Empty(), (IAsyncEnumerable)null, outer => outer, inner => inner, (outer, inner) => outer + inner)); + AssertExtensions.Throws("outerKeySelector", () => AsyncEnumerable.GroupJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (Func)null, inner => inner, (outer, inner) => outer + inner)); + AssertExtensions.Throws("innerKeySelector", () => AsyncEnumerable.GroupJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), outer => outer, (Func)null, (outer, inner) => outer + inner)); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.GroupJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), outer => outer, inner => inner, (Func, string>)null)); + + AssertExtensions.Throws("outer", () => AsyncEnumerable.GroupJoin((IAsyncEnumerable)null, AsyncEnumerable.Empty(), async (outer, ct) => outer, async (inner, ct) => inner, async (outer, inner, ct) => outer + inner)); + AssertExtensions.Throws("inner", () => AsyncEnumerable.GroupJoin(AsyncEnumerable.Empty(), (IAsyncEnumerable)null, async (outer, ct) => outer, async (inner, ct) => inner, async (outer, inner, ct) => outer + inner)); + AssertExtensions.Throws("outerKeySelector", () => AsyncEnumerable.GroupJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (Func>)null, async (inner, ct) => inner, async (outer, inner, ct) => outer + inner)); + AssertExtensions.Throws("innerKeySelector", () => AsyncEnumerable.GroupJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), async (outer, ct) => outer, (Func>)null, async (outer, inner, ct) => outer + inner)); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.GroupJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), async (outer, ct) => outer, async (inner, ct) => inner, (Func, CancellationToken, ValueTask>)null)); + } + + [Fact] + public async Task VariousValues_MatchesEnumerable_String() + { + Random rand = new(42); + foreach (int length in new[] { 0, 1, 2, 1000 }) + { + string[] values = new string[length]; + FillRandom(rand, values); + + foreach (IAsyncEnumerable source in CreateSources(values)) + { + await AssertEqual( + values.GroupJoin(values, s => s.Length > 0 ? s[0] : ' ', s => s.Length > 1 ? s[1] : ' ', (outer, inner) => outer + string.Concat(inner)), + source.GroupJoin(source, s => s.Length > 0 ? s[0] : ' ', s => s.Length > 1 ? s[1] : ' ', (outer, inner) => outer + string.Concat(inner))); + + await AssertEqual( + values.GroupJoin(values, s => s.Length > 0 ? s[0] : ' ', s => s.Length > 1 ? s[1] : ' ', (outer, inner) => outer + string.Concat(inner)), + source.GroupJoin(source, async (s, ct) => s.Length > 0 ? s[0] : ' ', async (s, ct) => s.Length > 1 ? s[1] : ' ', async (outer, inner, ct) => outer + string.Concat(inner))); + } + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.GroupJoin(source, outer => + { + cts.Cancel(); + return outer; + }, + inner => + { + return inner; + }, + (outer, inner) => + { + return outer + inner.Sum(); + }).WithCancellation(cts.Token)); + }); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.GroupJoin(source, + async (outer, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return outer; + }, + async (inner, ct) => + { + return inner; + }, + async (outer, inner, ct) => + { + return outer + inner.Sum(); + }).WithCancellation(cts.Token)); + }); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.GroupJoin(source, + async (outer, ct) => + { + return outer; + }, + async (inner, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return inner; + }, + async (outer, inner, ct) => + { + return outer + inner.Sum(); + }).WithCancellation(cts.Token)); + }); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.GroupJoin(source, + async (outer, ct) => + { + return outer; + }, + async (inner, ct) => + { + return inner; + }, + async (outer, inner, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return outer + inner.Sum(); + }).WithCancellation(cts.Token)); + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable outer, inner; + + outer = CreateSource(2, 4, 8, 16).Track(); + inner = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(outer.GroupJoin(inner, outer => outer, inner => inner, (outer, inner) => outer + inner.Sum())); + Assert.Equal(5, outer.MoveNextAsyncCount); + Assert.Equal(4, outer.CurrentCount); + Assert.Equal(1, outer.DisposeAsyncCount); + Assert.Equal(5, inner.MoveNextAsyncCount); + Assert.Equal(4, inner.CurrentCount); + Assert.Equal(1, inner.DisposeAsyncCount); + + outer = CreateSource(2, 4, 8, 16).Track(); + inner = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(outer.GroupJoin(inner, async (outer, ct) => outer, async (inner, ct) => inner, async (outer, inner, ct) => outer + inner.Sum())); + Assert.Equal(5, outer.MoveNextAsyncCount); + Assert.Equal(4, outer.CurrentCount); + Assert.Equal(1, outer.DisposeAsyncCount); + Assert.Equal(5, inner.MoveNextAsyncCount); + Assert.Equal(4, inner.CurrentCount); + Assert.Equal(1, inner.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/IndexTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/IndexTests.cs new file mode 100644 index 00000000000000..3c3c081bd5a16e --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/IndexTests.cs @@ -0,0 +1,60 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class IndexTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.Index(null)); + } + +#if NET + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + await AssertEqual( + values.Index(), + source.Index()); + } + } +#endif + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + CancellationTokenSource cts = new(); + await Assert.ThrowsAsync(async () => + { + await foreach ((int Index, int Item) item in source.Index().WithCancellation(cts.Token)) + { + cts.Cancel(); + } + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source = CreateSource(2, 4, 8, 16).Track(); + await ConsumeAsync(source.Index()); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/IntersectByTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/IntersectByTests.cs new file mode 100644 index 00000000000000..0d60650ae80b79 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/IntersectByTests.cs @@ -0,0 +1,130 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class IntersectByTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("first", () => AsyncEnumerable.IntersectBy((IAsyncEnumerable)null, AsyncEnumerable.Empty(), x => x.ToString())); + AssertExtensions.Throws("second", () => AsyncEnumerable.IntersectBy(AsyncEnumerable.Empty(), null, x => x.Length)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.IntersectBy(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (Func)null)); + + AssertExtensions.Throws("first", () => AsyncEnumerable.IntersectBy((IAsyncEnumerable)null, AsyncEnumerable.Empty(), async (x, ct) => x.ToString())); + AssertExtensions.Throws("second", () => AsyncEnumerable.IntersectBy(AsyncEnumerable.Empty(), null, async (x, ct) => x.Length)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.IntersectBy(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (Func>)null)); + } + +#if NET + [Theory] + [InlineData(new int[0], new int[0])] + [InlineData(new int[0], new int[] { 42 })] + [InlineData(new int[] { 42, 43 }, new int[0])] + [InlineData(new int[] { 1 }, new int[] { 2, 3 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 3, 5 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 2, 4, 8 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 2, 5, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 }, new int[] { int.MinValue, int.MaxValue })] + public async Task VariousValues_MatchesEnumerable(int[] firstInts, int[] second) + { + string[] first = firstInts.Select(x => x.ToString()).ToArray(); + + foreach (IAsyncEnumerable firstSource in CreateSources(first)) + { + foreach (IAsyncEnumerable secondSource in CreateSources(second)) + { + await AssertEqual( + first.IntersectBy(second, int.Parse), + firstSource.IntersectBy(secondSource, int.Parse)); + + await AssertEqual( + first.IntersectBy(second, int.Parse, OddEvenComparer), + firstSource.IntersectBy(secondSource, int.Parse, OddEvenComparer)); + + await AssertEqual( + first.IntersectBy(second, int.Parse), + firstSource.IntersectBy(secondSource, async (x, ct) => int.Parse(x))); + + await AssertEqual( + first.IntersectBy(second, int.Parse, OddEvenComparer), + firstSource.IntersectBy(secondSource, async (x, ct) => int.Parse(x), OddEvenComparer)); + } + } + } +#endif + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable first = CreateSource(2, 4, 8, 16); + IAsyncEnumerable second = CreateSource(2, 8, 32); + CancellationTokenSource cts; + + cts = new(); + await Assert.ThrowsAsync(async () => + { + await foreach (int item in first.IntersectBy(second, x => x).WithCancellation(cts.Token)) + { + cts.Cancel(); + } + }); + + cts = new(); + await Assert.ThrowsAsync(async () => + { + await ConsumeAsync(first.IntersectBy(second, x => + { + cts.Cancel(); + return x; + }).WithCancellation(cts.Token)); + }); + + cts = new(); + await Assert.ThrowsAsync(async () => + { + await ConsumeAsync(first.IntersectBy(second, async (x, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x; + }).WithCancellation(cts.Token)); + }); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task InterfaceCalls_ExpectedCounts(bool useAsync) + { + TrackingAsyncEnumerable first = CreateSource(2, 4, 8, 16, 32, 64).Track(); + TrackingAsyncEnumerable second = CreateSource(1, 3, 5).Track(); + int funcCount = 0; + await ConsumeAsync(useAsync ? + first.IntersectBy(second, async (x, ct) => + { + funcCount++; + return x; + }) : + first.IntersectBy(second, x => + { + funcCount++; + return x; + })); + Assert.Equal(7, first.MoveNextAsyncCount); + Assert.Equal(6, first.CurrentCount); + Assert.Equal(1, first.DisposeAsyncCount); + Assert.Equal(4, second.MoveNextAsyncCount); + Assert.Equal(3, second.CurrentCount); + Assert.Equal(1, second.DisposeAsyncCount); + Assert.Equal(6, funcCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/IntersectTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/IntersectTests.cs new file mode 100644 index 00000000000000..423ec8a5a7e741 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/IntersectTests.cs @@ -0,0 +1,81 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class IntersectTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("first", () => AsyncEnumerable.Intersect(null, AsyncEnumerable.Empty())); + AssertExtensions.Throws("second", () => AsyncEnumerable.Intersect(AsyncEnumerable.Empty(), null)); + } + + [Theory] + [InlineData(new int[0], new int[0])] + [InlineData(new int[0], new int[] { 42 })] + [InlineData(new int[] { 42, 43 }, new int[0])] + [InlineData(new int[] { 1 }, new int[] { 2, 3 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 3, 5 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 2, 4, 8 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 2, 5, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 }, new int[] { int.MinValue, int.MaxValue })] + public async Task VariousValues_MatchesEnumerable(int[] first, int[] second) + { + foreach (IAsyncEnumerable firstSource in CreateSources(first)) + { + foreach (IAsyncEnumerable secondSource in CreateSources(second)) + { + await AssertEqual( + first.Intersect(second), + firstSource.Intersect(secondSource)); + + await AssertEqual( + second.Intersect(first), + secondSource.Intersect(firstSource)); + + await AssertEqual( + first.Intersect(second, OddEvenComparer), + firstSource.Intersect(secondSource, OddEvenComparer)); + } + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable first = CreateSource(2, 4, 8, 16); + IAsyncEnumerable second = CreateSource(2, 5, 6, 7); + CancellationTokenSource cts = new(); + await Assert.ThrowsAsync(async () => + { + await foreach (int item in first.Intersect(second).WithCancellation(cts.Token)) + { + cts.Cancel(); + } + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable first = CreateSource(2, 4, 8, 16).Track(); + TrackingAsyncEnumerable second = CreateSource(1, 3, 5).Track(); + await ConsumeAsync(first.Intersect(second)); + + Assert.Equal(5, first.MoveNextAsyncCount); + Assert.Equal(4, first.CurrentCount); + Assert.Equal(1, first.DisposeAsyncCount); + + Assert.Equal(4, second.MoveNextAsyncCount); + Assert.Equal(3, second.CurrentCount); + Assert.Equal(1, second.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/JoinTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/JoinTests.cs new file mode 100644 index 00000000000000..5863f7fb2c4ef9 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/JoinTests.cs @@ -0,0 +1,164 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class JoinTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("outer", () => AsyncEnumerable.Join((IAsyncEnumerable)null, AsyncEnumerable.Empty(), outer => outer, inner => inner, (outer, inner) => outer + inner)); + AssertExtensions.Throws("inner", () => AsyncEnumerable.Join(AsyncEnumerable.Empty(), (IAsyncEnumerable)null, outer => outer, inner => inner, (outer, inner) => outer + inner)); + AssertExtensions.Throws("outerKeySelector", () => AsyncEnumerable.Join(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (Func)null, inner => inner, (outer, inner) => outer + inner)); + AssertExtensions.Throws("innerKeySelector", () => AsyncEnumerable.Join(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), outer => outer, (Func)null, (outer, inner) => outer + inner)); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.Join(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), outer => outer, inner => inner, (Func)null)); + + AssertExtensions.Throws("outer", () => AsyncEnumerable.Join((IAsyncEnumerable)null, AsyncEnumerable.Empty(), async (outer, ct) => outer, async (inner, ct) => inner, async (outer, inner, ct) => outer + inner)); + AssertExtensions.Throws("inner", () => AsyncEnumerable.Join(AsyncEnumerable.Empty(), (IAsyncEnumerable)null, async (outer, ct) => outer, async (inner, ct) => inner, async (outer, inner, ct) => outer + inner)); + AssertExtensions.Throws("outerKeySelector", () => AsyncEnumerable.Join(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (Func>)null, async (inner, ct) => inner, async (outer, inner, ct) => outer + inner)); + AssertExtensions.Throws("innerKeySelector", () => AsyncEnumerable.Join(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), async (outer, ct) => outer, (Func>)null, async (outer, inner, ct) => outer + inner)); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.Join(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), async (outer, ct) => outer, async (inner, ct) => inner, (Func>)null)); + } + + [Fact] + public async Task VariousValues_MatchesEnumerable_String() + { + Random rand = new(42); + foreach (int length in new[] { 0, 1, 2, 1000 }) + { + string[] values = new string[length]; + FillRandom(rand, values); + + foreach (IAsyncEnumerable source in CreateSources(values)) + { + await AssertEqual( + values.Join(values, s => s.Length > 0 ? s[0] : ' ', s => s.Length > 1 ? s[1] : ' ', (outer, inner) => outer + inner), + source.Join(source, s => s.Length > 0 ? s[0] : ' ', s => s.Length > 1 ? s[1] : ' ', (outer, inner) => outer + inner)); + + await AssertEqual( + values.Join(values, s => s.Length > 0 ? s[0] : ' ', s => s.Length > 1 ? s[1] : ' ', (outer, inner) => outer + inner), + source.Join(source, async (s, ct) => s.Length > 0 ? s[0] : ' ', async (s, ct) => s.Length > 1 ? s[1] : ' ', async (outer, inner, ct) => outer + inner)); + } + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.Join(source, outer => + { + cts.Cancel(); + return outer; + }, + inner => + { + return inner; + }, + (outer, inner) => + { + return outer + inner; + }).WithCancellation(cts.Token)); + }); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.Join(source, + async (outer, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return outer; + }, + async (inner, ct) => + { + return inner; + }, + async (outer, inner, ct) => + { + return outer + inner; + }).WithCancellation(cts.Token)); + }); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.Join(source, + async (outer, ct) => + { + return outer; + }, + async (inner, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return inner; + }, + async (outer, inner, ct) => + { + return outer + inner; + }).WithCancellation(cts.Token)); + }); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.Join(source, + async (outer, ct) => + { + return outer; + }, + async (inner, ct) => + { + return inner; + }, + async (outer, inner, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return outer + inner; + }).WithCancellation(cts.Token)); + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable outer, inner; + + outer = CreateSource(2, 4, 8, 16).Track(); + inner = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(outer.Join(inner, outer => outer, inner => inner, (outer, inner) => outer + inner)); + Assert.Equal(5, outer.MoveNextAsyncCount); + Assert.Equal(4, outer.CurrentCount); + Assert.Equal(1, outer.DisposeAsyncCount); + Assert.Equal(5, inner.MoveNextAsyncCount); + Assert.Equal(4, inner.CurrentCount); + Assert.Equal(1, inner.DisposeAsyncCount); + + outer = CreateSource(2, 4, 8, 16).Track(); + inner = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(outer.Join(inner, async (outer, ct) => outer, async (inner, ct) => inner, async (outer, inner, ct) => outer + inner)); + Assert.Equal(5, outer.MoveNextAsyncCount); + Assert.Equal(4, outer.CurrentCount); + Assert.Equal(1, outer.DisposeAsyncCount); + Assert.Equal(5, inner.MoveNextAsyncCount); + Assert.Equal(4, inner.CurrentCount); + Assert.Equal(1, inner.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/LastAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/LastAsyncTests.cs new file mode 100644 index 00000000000000..06736a674deba5 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/LastAsyncTests.cs @@ -0,0 +1,129 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class LastAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.LastAsync(null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.LastAsync(null, i => i % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.LastAsync(null, async (i, ct) => i % 2 == 0)); + + AssertExtensions.Throws("predicate", () => AsyncEnumerable.LastAsync(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.LastAsync(AsyncEnumerable.Empty(), (Func>)null)); + } + + [Fact] + public async Task EmptyInputs_Throws() + { + ValueTask first; + + first = AsyncEnumerable.Empty().LastAsync(); + await Assert.ThrowsAsync(async () => await first); + + first = AsyncEnumerable.Empty().LastAsync(i => i % 2 == 0); + await Assert.ThrowsAsync(async () => await first); + + first = AsyncEnumerable.Empty().LastAsync(async (i, ct) => i % 2 == 0); + await Assert.ThrowsAsync(async () => await first); + + first = new int[] { 1, 3, 5 }.ToAsyncEnumerable().LastAsync(i => i % 2 == 0); + await Assert.ThrowsAsync(async () => await first); + + first = new int[] { 1, 3, 5 }.ToAsyncEnumerable().LastAsync(async (i, ct) => i % 2 == 0); + await Assert.ThrowsAsync(async () => await first); + } + + [Theory] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { 1, 3, 5, 7 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + Assert.Equal( + values.Last(), + await source.LastAsync()); + + Func predicate = i => i < 5; + + Assert.Equal( + values.Last(predicate), + await source.LastAsync(predicate)); + + Assert.Equal( + values.Last(predicate), + await source.LastAsync(async (i, ct) => predicate(i))); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + CancellationTokenSource cts; + + await Assert.ThrowsAsync(async () => await source.LastAsync(new CancellationToken(true))); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.LastAsync(x => + { + cts.Cancel(); + return x > 32; + }, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.LastAsync(async (x, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x > 32; + }, cts.Token)); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + int predicateCount; + + source = CreateSource(2, 4, 8, 16).Track(); + await source.LastAsync(); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + predicateCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + await source.LastAsync(i => + { + predicateCount++; + return i == 8; + }); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + await source.LastAsync(async (i, ct) => + { + predicateCount++; + return i == 16; + }); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/LastOrDefaultAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/LastOrDefaultAsyncTests.cs new file mode 100644 index 00000000000000..e49facb6f702a7 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/LastOrDefaultAsyncTests.cs @@ -0,0 +1,130 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class LastOrDefaultAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.LastOrDefaultAsync(null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.LastOrDefaultAsync(null, i => i % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.LastOrDefaultAsync(null, async (i, ct) => i % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.LastOrDefaultAsync(null, 42)); + AssertExtensions.Throws("source", () => AsyncEnumerable.LastOrDefaultAsync(null, i => i % 2 == 0, 42)); + AssertExtensions.Throws("source", () => AsyncEnumerable.LastOrDefaultAsync(null, async (i, ct) => i % 2 == 0, 42)); + + AssertExtensions.Throws("predicate", () => AsyncEnumerable.LastOrDefaultAsync(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.LastOrDefaultAsync(AsyncEnumerable.Empty(), (Func>)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.LastOrDefaultAsync(AsyncEnumerable.Empty(), (Func)null, 42)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.LastOrDefaultAsync(AsyncEnumerable.Empty(), (Func>)null, 42)); + } + + [Fact] + public async Task EmptyInputs_DefaultValueReturned() + { + Assert.Equal(0, await AsyncEnumerable.Empty().LastOrDefaultAsync()); + Assert.Equal(42, await AsyncEnumerable.Empty().LastOrDefaultAsync(42)); + Assert.Equal(0, await AsyncEnumerable.Empty().LastOrDefaultAsync(i => i % 2 == 0)); + Assert.Equal(42, await AsyncEnumerable.Empty().LastOrDefaultAsync(i => i % 2 == 0, 42)); + Assert.Equal(0, await AsyncEnumerable.Empty().LastOrDefaultAsync(async (i, ct) => i % 2 == 0)); + Assert.Equal(42, await AsyncEnumerable.Empty().LastOrDefaultAsync(async (i, ct) => i % 2 == 0, 42)); + + IAsyncEnumerable source = new int[] { 1, 3, 5 }.ToAsyncEnumerable(); + Assert.Equal(0, await source.LastOrDefaultAsync(i => i % 2 == 0)); + Assert.Equal(42, await source.LastOrDefaultAsync(i => i % 2 == 0, 42)); + Assert.Equal(0, await source.LastOrDefaultAsync(async (i, ct) => i % 2 == 0)); + Assert.Equal(42, await source.LastOrDefaultAsync(async (i, ct) => i % 2 == 0, 42)); + } + + [Theory] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { 1, 3, 5, 7 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + Assert.Equal( + values.LastOrDefault(), + await source.LastOrDefaultAsync()); + + Func predicate = i => i < 5; + + Assert.Equal( + values.LastOrDefault(predicate), + await source.LastOrDefaultAsync(predicate)); + + Assert.Equal( + values.LastOrDefault(predicate), + await source.LastOrDefaultAsync(async (i, ct) => predicate(i))); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + CancellationTokenSource cts; + + await Assert.ThrowsAsync(async () => await source.LastOrDefaultAsync(new CancellationToken(true))); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.LastOrDefaultAsync(x => + { + cts.Cancel(); + return x > 32; + }, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.LastOrDefaultAsync(async (x, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x > 32; + }, cts.Token)); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + int predicateCount; + + source = CreateSource(2, 4, 8, 16).Track(); + await source.LastOrDefaultAsync(); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + predicateCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + await source.LastOrDefaultAsync(i => + { + predicateCount++; + return i == 8; + }); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + await source.LastOrDefaultAsync(async (i, ct) => + { + predicateCount++; + return i == 16; + }); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/LeftJoinTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/LeftJoinTests.cs new file mode 100644 index 00000000000000..334b8ba202b41a --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/LeftJoinTests.cs @@ -0,0 +1,166 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class LeftJoinTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("outer", () => AsyncEnumerable.LeftJoin((IAsyncEnumerable)null, AsyncEnumerable.Empty(), outer => outer, inner => inner, (outer, inner) => outer + inner)); + AssertExtensions.Throws("inner", () => AsyncEnumerable.LeftJoin(AsyncEnumerable.Empty(), (IAsyncEnumerable)null, outer => outer, inner => inner, (outer, inner) => outer + inner)); + AssertExtensions.Throws("outerKeySelector", () => AsyncEnumerable.LeftJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (Func)null, inner => inner, (outer, inner) => outer + inner)); + AssertExtensions.Throws("innerKeySelector", () => AsyncEnumerable.LeftJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), outer => outer, (Func)null, (outer, inner) => outer + inner)); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.LeftJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), outer => outer, inner => inner, (Func)null)); + + AssertExtensions.Throws("outer", () => AsyncEnumerable.LeftJoin((IAsyncEnumerable)null, AsyncEnumerable.Empty(), async (outer, ct) => outer, async (inner, ct) => inner, async (outer, inner, ct) => outer + inner)); + AssertExtensions.Throws("inner", () => AsyncEnumerable.LeftJoin(AsyncEnumerable.Empty(), (IAsyncEnumerable)null, async (outer, ct) => outer, async (inner, ct) => inner, async (outer, inner, ct) => outer + inner)); + AssertExtensions.Throws("outerKeySelector", () => AsyncEnumerable.LeftJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (Func>)null, async (inner, ct) => inner, async (outer, inner, ct) => outer + inner)); + AssertExtensions.Throws("innerKeySelector", () => AsyncEnumerable.LeftJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), async (outer, ct) => outer, (Func>)null, async (outer, inner, ct) => outer + inner)); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.LeftJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), async (outer, ct) => outer, async (inner, ct) => inner, (Func>)null)); + } + +#if NET + [Fact] + public async Task VariousValues_MatchesEnumerable_String() + { + Random rand = new(42); + foreach (int length in new[] { 0, 1, 2, 1000 }) + { + string[] values = new string[length]; + FillRandom(rand, values); + + foreach (IAsyncEnumerable source in CreateSources(values)) + { + await AssertEqual( + values.LeftJoin(values, s => s.Length > 0 ? s[0] : ' ', s => s.Length > 1 ? s[1] : ' ', (outer, inner) => outer + inner), + source.LeftJoin(source, s => s.Length > 0 ? s[0] : ' ', s => s.Length > 1 ? s[1] : ' ', (outer, inner) => outer + inner)); + + await AssertEqual( + values.LeftJoin(values, s => s.Length > 0 ? s[0] : ' ', s => s.Length > 1 ? s[1] : ' ', (outer, inner) => outer + inner), + source.LeftJoin(source, async (s, ct) => s.Length > 0 ? s[0] : ' ', async (s, ct) => s.Length > 1 ? s[1] : ' ', async (outer, inner, ct) => outer + inner)); + } + } + } +#endif + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.LeftJoin(source, outer => + { + cts.Cancel(); + return outer; + }, + inner => + { + return inner; + }, + (outer, inner) => + { + return outer + inner; + }).WithCancellation(cts.Token)); + }); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.LeftJoin(source, + async (outer, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return outer; + }, + async (inner, ct) => + { + return inner; + }, + async (outer, inner, ct) => + { + return outer + inner; + }).WithCancellation(cts.Token)); + }); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.LeftJoin(source, + async (outer, ct) => + { + return outer; + }, + async (inner, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return inner; + }, + async (outer, inner, ct) => + { + return outer + inner; + }).WithCancellation(cts.Token)); + }); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.LeftJoin(source, + async (outer, ct) => + { + return outer; + }, + async (inner, ct) => + { + return inner; + }, + async (outer, inner, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return outer + inner; + }).WithCancellation(cts.Token)); + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable outer, inner; + + outer = CreateSource(2, 4, 8, 16).Track(); + inner = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(outer.LeftJoin(inner, outer => outer, inner => inner, (outer, inner) => outer + inner)); + Assert.Equal(5, outer.MoveNextAsyncCount); + Assert.Equal(4, outer.CurrentCount); + Assert.Equal(1, outer.DisposeAsyncCount); + Assert.Equal(5, inner.MoveNextAsyncCount); + Assert.Equal(4, inner.CurrentCount); + Assert.Equal(1, inner.DisposeAsyncCount); + + outer = CreateSource(2, 4, 8, 16).Track(); + inner = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(outer.LeftJoin(inner, async (outer, ct) => outer, async (inner, ct) => inner, async (outer, inner, ct) => outer + inner)); + Assert.Equal(5, outer.MoveNextAsyncCount); + Assert.Equal(4, outer.CurrentCount); + Assert.Equal(1, outer.DisposeAsyncCount); + Assert.Equal(5, inner.MoveNextAsyncCount); + Assert.Equal(4, inner.CurrentCount); + Assert.Equal(1, inner.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/MaxAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/MaxAsyncTests.cs new file mode 100644 index 00000000000000..001e4b1f0aef30 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/MaxAsyncTests.cs @@ -0,0 +1,175 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class MaxAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.MaxAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MaxAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MaxAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MaxAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MaxAsync((IAsyncEnumerable)null)); + + AssertExtensions.Throws("source", () => AsyncEnumerable.MaxAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MaxAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MaxAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MaxAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MaxAsync((IAsyncEnumerable)null)); + + AssertExtensions.Throws("source", () => AsyncEnumerable.MaxAsync(null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MaxAsync(null, Comparer.Default, default)); + } + + [Fact] + public async Task EmptyInputs_NonNullable_Throws() + { + await Assert.ThrowsAsync(async () => await AsyncEnumerable.MaxAsync(AsyncEnumerable.Empty())); + await Assert.ThrowsAsync(async () => await AsyncEnumerable.MaxAsync(AsyncEnumerable.Empty())); + await Assert.ThrowsAsync(async () => await AsyncEnumerable.MaxAsync(AsyncEnumerable.Empty())); + await Assert.ThrowsAsync(async () => await AsyncEnumerable.MaxAsync(AsyncEnumerable.Empty())); + await Assert.ThrowsAsync(async () => await AsyncEnumerable.MaxAsync(AsyncEnumerable.Empty())); + + Assert.Null(await AsyncEnumerable.MaxAsync(AsyncEnumerable.Empty())); + Assert.Null(await AsyncEnumerable.MaxAsync(AsyncEnumerable.Empty())); + Assert.Null(await AsyncEnumerable.MaxAsync(AsyncEnumerable.Empty())); + Assert.Null(await AsyncEnumerable.MaxAsync(AsyncEnumerable.Empty())); + Assert.Null(await AsyncEnumerable.MaxAsync(AsyncEnumerable.Empty())); + + await Assert.ThrowsAsync(async () => await AsyncEnumerable.MaxAsync(AsyncEnumerable.Empty())); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 0 })] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { -1000, 1000 })] + [InlineData(new int[] { -1, -2, -3 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + if (values.Length > 0) + { + Assert.Equal(values.Select(i => (int)i).Max(), await source.Select(i => (int)i).MaxAsync()); + Assert.Equal(values.Select(i => (long)i).Max(), await source.Select(i => (long)i).MaxAsync()); + Assert.Equal(values.Select(i => (float)i).Max(), await source.Select(i => (float)i).MaxAsync()); + Assert.Equal(values.Select(i => (double)i).Max(), await source.Select(i => (double)i).MaxAsync()); + Assert.Equal(values.Select(i => (decimal)i).Max(), await source.Select(i => (decimal)i).MaxAsync()); + +#if NET + Assert.Equal(values.Select(i => (int)i).Max(Comparer.Create((x, y) => y.CompareTo(x))), await source.Select(i => (int)i).MaxAsync(Comparer.Create((x, y) => y.CompareTo(x)))); + Assert.Equal(values.Select(i => (long)i).Max(Comparer.Create((x, y) => y.CompareTo(x))), await source.Select(i => (long)i).MaxAsync(Comparer.Create((x, y) => y.CompareTo(x)))); + Assert.Equal(values.Select(i => (float)i).Max(Comparer.Create((x, y) => y.CompareTo(x))), await source.Select(i => (float)i).MaxAsync(Comparer.Create((x, y) => y.CompareTo(x)))); + Assert.Equal(values.Select(i => (double)i).Max(Comparer.Create((x, y) => y.CompareTo(x))), await source.Select(i => (double)i).MaxAsync(Comparer.Create((x, y) => y.CompareTo(x)))); + Assert.Equal(values.Select(i => (decimal)i).Max(Comparer.Create((x, y) => y.CompareTo(x))), await source.Select(i => (decimal)i).MaxAsync(Comparer.Create((x, y) => y.CompareTo(x)))); +#endif + + Assert.Equal(values.Select(i => TimeSpan.FromSeconds(i)).Max(), await source.Select(i => TimeSpan.FromSeconds(i)).MaxAsync()); + } + + Assert.Equal(values.Select(i => (int?)i).Max(), await source.Select(i => (int?)i).MaxAsync()); + Assert.Equal(values.Select(i => (long?)i).Max(), await source.Select(i => (long?)i).MaxAsync()); + Assert.Equal(values.Select(i => (float?)i).Max(), await source.Select(i => (float?)i).MaxAsync()); + Assert.Equal(values.Select(i => (double?)i).Max(), await source.Select(i => (double?)i).MaxAsync()); + Assert.Equal(values.Select(i => (decimal?)i).Max(), await source.Select(i => (decimal?)i).MaxAsync()); + +#if NET + Assert.Equal(values.Select(i => (int?)i).Max(Comparer.Create((x, y) => Nullable.Compare(y, x))), await source.Select(i => (int?)i).MaxAsync(Comparer.Create((x, y) => Nullable.Compare(y, x)))); + Assert.Equal(values.Select(i => (long?)i).Max(Comparer.Create((x, y) => Nullable.Compare(y, x))), await source.Select(i => (long?)i).MaxAsync(Comparer.Create((x, y) => Nullable.Compare(y, x)))); + Assert.Equal(values.Select(i => (float?)i).Max(Comparer.Create((x, y) => Nullable.Compare(y, x))), await source.Select(i => (float?)i).MaxAsync(Comparer.Create((x, y) => Nullable.Compare(y, x)))); + Assert.Equal(values.Select(i => (double?)i).Max(Comparer.Create((x, y) => Nullable.Compare(y, x))), await source.Select(i => (double?)i).MaxAsync(Comparer.Create((x, y) => Nullable.Compare(y, x)))); + Assert.Equal(values.Select(i => (decimal?)i).Max(Comparer.Create((x, y) => Nullable.Compare(y, x))), await source.Select(i => (decimal?)i).MaxAsync(Comparer.Create((x, y) => Nullable.Compare(y, x)))); +#endif + + // With NaNs + foreach (double[] special in new double[][] { [double.NaN, double.NaN], [1.0, double.NaN], [double.NaN, 1.0] }) + { + Assert.Equal( + special.Select(d => (float)d).Concat(values.Select(i => (float)i)).Concat(special.Select(d => (float)d)).Max(), + await special.Select(d => (float)d).ToAsyncEnumerable().Concat(source.Select(i => (float)i)).Concat(special.Select(d => (float)d).ToAsyncEnumerable()).MaxAsync()); + Assert.Equal( + special.Concat(values.Select(i => (double)i)).Concat(special).Max(), + await special.ToAsyncEnumerable().Concat(source.Select(i => (double)i)).Concat(special.ToAsyncEnumerable()).MaxAsync()); + Assert.Equal( + special.Select(d => (float?)d).Concat(values.Select(i => (float?)i)).Concat(special.Select(d => (float?)d)).Max(), + await special.Select(d => (float?)d).ToAsyncEnumerable().Concat(source.Select(i => (float?)i)).Concat(special.Select(d => (float?)d).ToAsyncEnumerable()).MaxAsync()); + Assert.Equal( + special.Select(d => (double?)d).Concat(values.Select(i => (double?)i)).Concat(special.Select(d => (double?)d)).Max(), + await special.Select(d => (double?)d).ToAsyncEnumerable().Concat(source.Select(i => (double?)i)).Concat(special.Select(d => (double?)d).ToAsyncEnumerable()).MaxAsync()); + } + + // With nulls + Assert.Equal( + new float?[] { null, null }.Concat(values.Select(i => (float?)i)).Concat(new float?[] { null, null }).Max(), + await new float?[] { null, null }.ToAsyncEnumerable().Concat(source.Select(i => (float?)i)).Concat(new float?[] { null, null }.ToAsyncEnumerable()).MaxAsync()); + Assert.Equal( + new double?[] { null, null }.Concat(values.Select(i => (double?)i)).Concat(new double?[] { null, null }).Max(), + await new double?[] { null, null }.ToAsyncEnumerable().Concat(source.Select(i => (double?)i)).Concat(new double?[] { null, null }.ToAsyncEnumerable()).MaxAsync()); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (int)i).MaxAsync(null, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (long)i).MaxAsync(null, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (float)i).MaxAsync(null, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (double)i).MaxAsync(null, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (decimal)i).MaxAsync(null, new CancellationToken(true))); + + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (int?)i).MaxAsync(null, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (long?)i).MaxAsync(null, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (float?)i).MaxAsync(null, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (double?)i).MaxAsync(null, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (decimal?)i).MaxAsync(null, new CancellationToken(true))); + + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => TimeSpan.FromSeconds(i)).MaxAsync(null, new CancellationToken(true))); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + await Validate(source => source.Select(i => (int)i).MaxAsync()); + await Validate(source => source.Select(i => (long)i).MaxAsync()); + await Validate(source => source.Select(i => (float)i).MaxAsync()); + await Validate(source => source.Select(i => (double)i).MaxAsync()); + await Validate(source => source.Select(i => (decimal)i).MaxAsync()); + + await Validate(source => source.Select(i => (int?)i).MaxAsync()); + await Validate(source => source.Select(i => (long?)i).MaxAsync()); + await Validate(source => source.Select(i => (float?)i).MaxAsync()); + await Validate(source => source.Select(i => (double?)i).MaxAsync()); + await Validate(source => source.Select(i => (decimal?)i).MaxAsync()); + + await Validate(source => source.Select(i => TimeSpan.FromSeconds(i)).MaxAsync()); + + static async Task Validate(Func, ValueTask> factory) + { + TrackingAsyncEnumerable source; + + source = CreateSource(2, 4, 8, 16).Track(); + await factory(source); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).AppendException(new FormatException()).Track(); + await Assert.ThrowsAsync(async () => await factory(source)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/MaxByAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/MaxByAsyncTests.cs new file mode 100644 index 00000000000000..95307dd7bee933 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/MaxByAsyncTests.cs @@ -0,0 +1,153 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class MaxByAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.MaxByAsync((IAsyncEnumerable)null, i => i * 2)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MaxByAsync((IAsyncEnumerable)null, async (i, ct) => i * 2)); + + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.MaxByAsync(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.MaxByAsync(AsyncEnumerable.Empty(), (Func>)null)); + } + + [Fact] + public async Task EmptyInputs_ThrowsForNonNullableTypes() + { + await AsyncEnumerable.Empty().MaxByAsync(i => i.Length); + await AsyncEnumerable.Empty().MaxByAsync(async (i, ct) => i.Length); + + await AsyncEnumerable.Empty().MaxByAsync(i => i); + await AsyncEnumerable.Empty().MaxByAsync(async (i, ct) => i); + + ValueTask result; + + result = AsyncEnumerable.Empty().MaxByAsync(i => i); + await Assert.ThrowsAsync(async () => await result); + + result = AsyncEnumerable.Empty().MaxByAsync(async (i, ct) => i); + await Assert.ThrowsAsync(async () => await result); + } + +#if NET + [Theory] + [InlineData(new int[] { 0 })] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { 1, 8, 2, 7, 3, 6, 4, 5 })] + [InlineData(new int[] { -1000, 1000 })] + [InlineData(new int[] { -1, -2, -3 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + foreach (IComparer comparer in new[] { null, Comparer.Default, Comparer.Create((x, y) => y.CompareTo(x)) }) + { + Assert.Equal( + values.MaxBy(i => -i, comparer), + await source.MaxByAsync(i => -i, comparer)); + + Assert.Equal( + values.MaxBy(i => -i, comparer), + await source.MaxByAsync(async (i, ct) => -i, comparer)); + } + + foreach (IComparer comparer in new IComparer[] { null, Comparer.Default, StringComparer.OrdinalIgnoreCase }) + { + Assert.Equal( + values.Select(i => i.ToString()).MaxBy(s => s.ToLower(), comparer), + await source.Select(i => i.ToString()).MaxByAsync(s => s.ToLower(), comparer)); + + Assert.Equal( + values.Select(i => i.ToString()).MaxBy(s => s.ToLower(), comparer), + await source.Select(i => i.ToString()).MaxByAsync(async (s, ct) => s.ToLower(), comparer)); + } + + foreach (IComparer comparer in new IComparer[] { null, Comparer.Default, StringComparer.OrdinalIgnoreCase }) + { + Assert.Equal( + values.Select(i => i.ToString()).MaxBy(s => null, comparer), + await source.Select(i => i.ToString()).MaxByAsync(s => null, comparer)); + + Assert.Equal( + values.Select(i => i.ToString()).MaxBy(s => s.CompareTo("3") < 0 ? null : s, comparer), + await source.Select(i => i.ToString()).MaxByAsync(s => s.CompareTo("3") < 0 ? null : s, comparer)); + + Assert.Equal( + values.Select(i => i.ToString()).MaxBy(s => null, comparer), + await source.Select(i => i.ToString()).MaxByAsync(async (s, ct) => null, comparer)); + + Assert.Equal( + values.Select(i => i.ToString()).MaxBy(s => s.CompareTo("3") < 0 ? null : s, comparer), + await source.Select(i => i.ToString()).MaxByAsync(async (s, ct) => s.CompareTo("3") < 0 ? null : s, comparer)); + } + } + } +#endif + + [Fact] + public async Task Cancellation_Cancels() + { + TrackingAsyncEnumerable source = CreateSource(2, 4).Track(); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await source.MaxByAsync(i => + { + cts.Cancel(); + return i; + }, null, cts.Token); + }); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await source.MaxByAsync(async (i, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return i; + }, null, cts.Token); + }); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task InterfaceCalls_ExpectedCounts(bool useAsync) + { + TrackingAsyncEnumerable source; + int keySelectorCount; + + keySelectorCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + await (useAsync ? + source.MaxByAsync(async (i, ct) => + { + keySelectorCount++; + return i; + }) : + source.MaxByAsync(i => + { + keySelectorCount++; + return i; + })); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, keySelectorCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/MinAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/MinAsyncTests.cs new file mode 100644 index 00000000000000..6ebe3b585a8734 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/MinAsyncTests.cs @@ -0,0 +1,175 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class MinAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.MinAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MinAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MinAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MinAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MinAsync((IAsyncEnumerable)null)); + + AssertExtensions.Throws("source", () => AsyncEnumerable.MinAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MinAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MinAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MinAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MinAsync((IAsyncEnumerable)null)); + + AssertExtensions.Throws("source", () => AsyncEnumerable.MinAsync(null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MinAsync(null, Comparer.Default, default)); + } + + [Fact] + public async Task EmptyInputs_NonNullable_Throws() + { + await Assert.ThrowsAsync(async () => await AsyncEnumerable.MinAsync(AsyncEnumerable.Empty())); + await Assert.ThrowsAsync(async () => await AsyncEnumerable.MinAsync(AsyncEnumerable.Empty())); + await Assert.ThrowsAsync(async () => await AsyncEnumerable.MinAsync(AsyncEnumerable.Empty())); + await Assert.ThrowsAsync(async () => await AsyncEnumerable.MinAsync(AsyncEnumerable.Empty())); + await Assert.ThrowsAsync(async () => await AsyncEnumerable.MinAsync(AsyncEnumerable.Empty())); + + Assert.Null(await AsyncEnumerable.MinAsync(AsyncEnumerable.Empty())); + Assert.Null(await AsyncEnumerable.MinAsync(AsyncEnumerable.Empty())); + Assert.Null(await AsyncEnumerable.MinAsync(AsyncEnumerable.Empty())); + Assert.Null(await AsyncEnumerable.MinAsync(AsyncEnumerable.Empty())); + Assert.Null(await AsyncEnumerable.MinAsync(AsyncEnumerable.Empty())); + + await Assert.ThrowsAsync(async () => await AsyncEnumerable.MinAsync(AsyncEnumerable.Empty())); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 0 })] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { -1000, 1000 })] + [InlineData(new int[] { -1, -2, -3 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + if (values.Length > 0) + { + Assert.Equal(values.Select(i => (int)i).Min(), await source.Select(i => (int)i).MinAsync()); + Assert.Equal(values.Select(i => (long)i).Min(), await source.Select(i => (long)i).MinAsync()); + Assert.Equal(values.Select(i => (float)i).Min(), await source.Select(i => (float)i).MinAsync()); + Assert.Equal(values.Select(i => (double)i).Min(), await source.Select(i => (double)i).MinAsync()); + Assert.Equal(values.Select(i => (decimal)i).Min(), await source.Select(i => (decimal)i).MinAsync()); + +#if NET + Assert.Equal(values.Select(i => (int)i).Min(Comparer.Create((x, y) => y.CompareTo(x))), await source.Select(i => (int)i).MinAsync(Comparer.Create((x, y) => y.CompareTo(x)))); + Assert.Equal(values.Select(i => (long)i).Min(Comparer.Create((x, y) => y.CompareTo(x))), await source.Select(i => (long)i).MinAsync(Comparer.Create((x, y) => y.CompareTo(x)))); + Assert.Equal(values.Select(i => (float)i).Min(Comparer.Create((x, y) => y.CompareTo(x))), await source.Select(i => (float)i).MinAsync(Comparer.Create((x, y) => y.CompareTo(x)))); + Assert.Equal(values.Select(i => (double)i).Min(Comparer.Create((x, y) => y.CompareTo(x))), await source.Select(i => (double)i).MinAsync(Comparer.Create((x, y) => y.CompareTo(x)))); + Assert.Equal(values.Select(i => (decimal)i).Min(Comparer.Create((x, y) => y.CompareTo(x))), await source.Select(i => (decimal)i).MinAsync(Comparer.Create((x, y) => y.CompareTo(x)))); +#endif + + Assert.Equal(values.Select(i => TimeSpan.FromSeconds(i)).Min(), await source.Select(i => TimeSpan.FromSeconds(i)).MinAsync()); + } + + Assert.Equal(values.Select(i => (int?)i).Min(), await source.Select(i => (int?)i).MinAsync()); + Assert.Equal(values.Select(i => (long?)i).Min(), await source.Select(i => (long?)i).MinAsync()); + Assert.Equal(values.Select(i => (float?)i).Min(), await source.Select(i => (float?)i).MinAsync()); + Assert.Equal(values.Select(i => (double?)i).Min(), await source.Select(i => (double?)i).MinAsync()); + Assert.Equal(values.Select(i => (decimal?)i).Min(), await source.Select(i => (decimal?)i).MinAsync()); + +#if NET + Assert.Equal(values.Select(i => (int?)i).Min(Comparer.Create((x, y) => Nullable.Compare(y, x))), await source.Select(i => (int?)i).MinAsync(Comparer.Create((x, y) => Nullable.Compare(y, x)))); + Assert.Equal(values.Select(i => (long?)i).Min(Comparer.Create((x, y) => Nullable.Compare(y, x))), await source.Select(i => (long?)i).MinAsync(Comparer.Create((x, y) => Nullable.Compare(y, x)))); + Assert.Equal(values.Select(i => (float?)i).Min(Comparer.Create((x, y) => Nullable.Compare(y, x))), await source.Select(i => (float?)i).MinAsync(Comparer.Create((x, y) => Nullable.Compare(y, x)))); + Assert.Equal(values.Select(i => (double?)i).Min(Comparer.Create((x, y) => Nullable.Compare(y, x))), await source.Select(i => (double?)i).MinAsync(Comparer.Create((x, y) => Nullable.Compare(y, x)))); + Assert.Equal(values.Select(i => (decimal?)i).Min(Comparer.Create((x, y) => Nullable.Compare(y, x))), await source.Select(i => (decimal?)i).MinAsync(Comparer.Create((x, y) => Nullable.Compare(y, x)))); +#endif + + // With NaNs + foreach (double[] special in new double[][] { [double.NaN, double.NaN], [1.0, double.NaN], [double.NaN, 1.0] }) + { + Assert.Equal( + special.Select(d => (float)d).Concat(values.Select(i => (float)i)).Concat(special.Select(d => (float)d)).Min(), + await special.Select(d => (float)d).ToAsyncEnumerable().Concat(source.Select(i => (float)i)).Concat(special.Select(d => (float)d).ToAsyncEnumerable()).MinAsync()); + Assert.Equal( + special.Concat(values.Select(i => (double)i)).Concat(special).Min(), + await special.ToAsyncEnumerable().Concat(source.Select(i => (double)i)).Concat(special.ToAsyncEnumerable()).MinAsync()); + Assert.Equal( + special.Select(d => (float?)d).Concat(values.Select(i => (float?)i)).Concat(special.Select(d => (float?)d)).Min(), + await special.Select(d => (float?)d).ToAsyncEnumerable().Concat(source.Select(i => (float?)i)).Concat(special.Select(d => (float?)d).ToAsyncEnumerable()).MinAsync()); + Assert.Equal( + special.Select(d => (double?)d).Concat(values.Select(i => (double?)i)).Concat(special.Select(d => (double?)d)).Min(), + await special.Select(d => (double?)d).ToAsyncEnumerable().Concat(source.Select(i => (double?)i)).Concat(special.Select(d => (double?)d).ToAsyncEnumerable()).MinAsync()); + } + + // With nulls + Assert.Equal( + new float?[] { null, null }.Concat(values.Select(i => (float?)i)).Concat(new float?[] { null, null }).Min(), + await new float?[] { null, null }.ToAsyncEnumerable().Concat(source.Select(i => (float?)i)).Concat(new float?[] { null, null }.ToAsyncEnumerable()).MinAsync()); + Assert.Equal( + new double?[] { null, null }.Concat(values.Select(i => (double?)i)).Concat(new double?[] { null, null }).Min(), + await new double?[] { null, null }.ToAsyncEnumerable().Concat(source.Select(i => (double?)i)).Concat(new double?[] { null, null }.ToAsyncEnumerable()).MinAsync()); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (int)i).MinAsync(null, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (long)i).MinAsync(null, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (float)i).MinAsync(null, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (double)i).MinAsync(null, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (decimal)i).MinAsync(null, new CancellationToken(true))); + + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (int?)i).MinAsync(null, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (long?)i).MinAsync(null, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (float?)i).MinAsync(null, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (double?)i).MinAsync(null, new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (decimal?)i).MinAsync(null, new CancellationToken(true))); + + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => TimeSpan.FromSeconds(i)).MinAsync(null, new CancellationToken(true))); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + await Validate(source => source.Select(i => (int)i).MinAsync()); + await Validate(source => source.Select(i => (long)i).MinAsync()); + await Validate(source => source.Select(i => (float)i).MinAsync()); + await Validate(source => source.Select(i => (double)i).MinAsync()); + await Validate(source => source.Select(i => (decimal)i).MinAsync()); + + await Validate(source => source.Select(i => (int?)i).MinAsync()); + await Validate(source => source.Select(i => (long?)i).MinAsync()); + await Validate(source => source.Select(i => (float?)i).MinAsync()); + await Validate(source => source.Select(i => (double?)i).MinAsync()); + await Validate(source => source.Select(i => (decimal?)i).MinAsync()); + + await Validate(source => source.Select(i => TimeSpan.FromSeconds(i)).MinAsync()); + + static async Task Validate(Func, ValueTask> factory) + { + TrackingAsyncEnumerable source; + + source = CreateSource(2, 4, 8, 16).Track(); + await factory(source); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).AppendException(new FormatException()).Track(); + await Assert.ThrowsAsync(async () => await factory(source)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/MinByAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/MinByAsyncTests.cs new file mode 100644 index 00000000000000..2d9eedcbe74cd1 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/MinByAsyncTests.cs @@ -0,0 +1,153 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class MinByAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.MinByAsync((IAsyncEnumerable)null, i => i * 2)); + AssertExtensions.Throws("source", () => AsyncEnumerable.MinByAsync((IAsyncEnumerable)null, async (i, ct) => i * 2)); + + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.MinByAsync(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.MinByAsync(AsyncEnumerable.Empty(), (Func>)null)); + } + + [Fact] + public async Task EmptyInputs_ThrowsForNonNullableTypes() + { + await AsyncEnumerable.Empty().MinByAsync(i => i.Length); + await AsyncEnumerable.Empty().MinByAsync(async (i, ct) => i.Length); + + await AsyncEnumerable.Empty().MinByAsync(i => i); + await AsyncEnumerable.Empty().MinByAsync(async (i, ct) => i); + + ValueTask result; + + result = AsyncEnumerable.Empty().MinByAsync(i => i); + await Assert.ThrowsAsync(async () => await result); + + result = AsyncEnumerable.Empty().MinByAsync(async (i, ct) => i); + await Assert.ThrowsAsync(async () => await result); + } + +#if NET + [Theory] + [InlineData(new int[] { 0 })] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { 1, 8, 2, 7, 3, 6, 4, 5 })] + [InlineData(new int[] { -1000, 1000 })] + [InlineData(new int[] { -1, -2, -3 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + foreach (IComparer comparer in new[] { null, Comparer.Default, Comparer.Create((x, y) => y.CompareTo(x)) }) + { + Assert.Equal( + values.MinBy(i => -i, comparer), + await source.MinByAsync(i => -i, comparer)); + + Assert.Equal( + values.MinBy(i => -i, comparer), + await source.MinByAsync(async (i, ct) => -i, comparer)); + } + + foreach (IComparer comparer in new IComparer[] { null, Comparer.Default, StringComparer.OrdinalIgnoreCase }) + { + Assert.Equal( + values.Select(i => i.ToString()).MinBy(s => s.ToLower(), comparer), + await source.Select(i => i.ToString()).MinByAsync(s => s.ToLower(), comparer)); + + Assert.Equal( + values.Select(i => i.ToString()).MinBy(s => s.ToLower(), comparer), + await source.Select(i => i.ToString()).MinByAsync(async (s, ct) => s.ToLower(), comparer)); + } + + foreach (IComparer comparer in new IComparer[] { null, Comparer.Default, StringComparer.OrdinalIgnoreCase }) + { + Assert.Equal( + values.Select(i => i.ToString()).MinBy(s => null, comparer), + await source.Select(i => i.ToString()).MinByAsync(s => null, comparer)); + + Assert.Equal( + values.Select(i => i.ToString()).MinBy(s => s.CompareTo("3") < 0 ? null : s, comparer), + await source.Select(i => i.ToString()).MinByAsync(s => s.CompareTo("3") < 0 ? null : s, comparer)); + + Assert.Equal( + values.Select(i => i.ToString()).MinBy(s => null, comparer), + await source.Select(i => i.ToString()).MinByAsync(async (s, ct) => null, comparer)); + + Assert.Equal( + values.Select(i => i.ToString()).MinBy(s => s.CompareTo("3") < 0 ? null : s, comparer), + await source.Select(i => i.ToString()).MinByAsync(async (s, ct) => s.CompareTo("3") < 0 ? null : s, comparer)); + } + } + } +#endif + + [Fact] + public async Task Cancellation_Cancels() + { + TrackingAsyncEnumerable source = CreateSource(2, 4).Track(); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await source.MinByAsync(i => + { + cts.Cancel(); + return i; + }, null, cts.Token); + }); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await source.MinByAsync(async (i, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return i; + }, null, cts.Token); + }); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task InterfaceCalls_ExpectedCounts(bool useAsync) + { + TrackingAsyncEnumerable source; + int keySelectorCount; + + keySelectorCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + await (useAsync ? + source.MinByAsync(async (i, ct) => + { + keySelectorCount++; + return i; + }) : + source.MinByAsync(i => + { + keySelectorCount++; + return i; + })); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, keySelectorCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/OfTypeTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/OfTypeTests.cs new file mode 100644 index 00000000000000..6bf667b12699ca --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/OfTypeTests.cs @@ -0,0 +1,59 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class OfTypeTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.OfType(null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.OfType(null)); + } + + [Fact] + public async Task Empty_ProducesEmpty() + { + await AssertEqual(AsyncEnumerable.Empty(), AsyncEnumerable.Empty().OfType()); + await AssertEqual(AsyncEnumerable.Empty(), AsyncEnumerable.Empty().OfType()); + } + + [Fact] + public async Task NullAndNonNull_SkipsNulls() + { + await AssertEqual(["2", "8"], CreateSource("2", null, "8", null).OfType()); + await AssertEqual(["2", "8"], CreateSource("2", null, "8", null).OfType()); + await AssertEqual([2, 8], CreateSource(2, null, 8, null).OfType()); + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource("2", null, "8", null); + CancellationTokenSource cts = new(); + await Assert.ThrowsAsync(async () => + { + await foreach (string item in source.OfType().WithCancellation(cts.Token)) + { + cts.Cancel(); + } + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source = CreateSource(2, null, 8, 16).Track(); + await ConsumeAsync(source.OfType()); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/OrderByTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/OrderByTests.cs new file mode 100644 index 00000000000000..558bf4c1dfaf67 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/OrderByTests.cs @@ -0,0 +1,186 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class OrderByTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.Order((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.OrderDescending((IAsyncEnumerable)null)); + + AssertExtensions.Throws("source", () => AsyncEnumerable.OrderBy((IAsyncEnumerable)null, i => i)); + AssertExtensions.Throws("source", () => AsyncEnumerable.OrderBy((IAsyncEnumerable)null, async (i, ct) => i)); + AssertExtensions.Throws("source", () => AsyncEnumerable.OrderByDescending((IAsyncEnumerable)null, i => i)); + AssertExtensions.Throws("source", () => AsyncEnumerable.OrderByDescending((IAsyncEnumerable)null, async (i, ct) => i)); + + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.OrderBy(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.OrderBy(AsyncEnumerable.Empty(), (Func>)null)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.OrderByDescending(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.OrderByDescending(AsyncEnumerable.Empty(), (Func>)null)); + + AssertExtensions.Throws("source", () => AsyncEnumerable.ThenBy((IOrderedAsyncEnumerable)null, i => i)); + AssertExtensions.Throws("source", () => AsyncEnumerable.ThenBy((IOrderedAsyncEnumerable)null, async (i, ct) => i)); + AssertExtensions.Throws("source", () => AsyncEnumerable.ThenByDescending((IOrderedAsyncEnumerable)null, i => i)); + AssertExtensions.Throws("source", () => AsyncEnumerable.ThenByDescending((IOrderedAsyncEnumerable)null, async (i, ct) => i)); + + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.ThenBy(AsyncEnumerable.Empty().Order(), (Func)null)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.ThenBy(AsyncEnumerable.Empty().Order(), (Func>)null)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.ThenByDescending(AsyncEnumerable.Empty().Order(), (Func)null)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.ThenByDescending(AsyncEnumerable.Empty().Order(), (Func>)null)); + } + + [Fact] + public async Task VariousValues_MatchesEnumerable_Int32() + { + Random rand = new(42); + foreach (int length in new[] { 0, 1, 2, 3, 4, 100, 1024 }) + { + foreach (IComparer comparer in new[] { null, Comparer.Default, Comparer.Create((x, y) => y.CompareTo(x)) }) + { + int[] ints = new int[length]; + FillRandom(rand, ints); + int[] copy = ints.ToArray(); + + foreach (IAsyncEnumerable source in CreateSources(ints)) + { +#if NET + await AssertEqual( + ints.Order(comparer), + source.Order(comparer)); + + await AssertEqual( + ints.OrderDescending(comparer), + source.OrderDescending(comparer)); +#endif + + await AssertEqual( + ints.OrderBy(i => i % 2 == 0 ? i : -1, comparer), + source.OrderBy(i => i % 2 == 0 ? i : -1, comparer)); + + await AssertEqual( + ints.OrderBy(i => i % 2 == 0 ? i : -1, comparer), + source.OrderBy(async (i, ct) => i % 2 == 0 ? i : -1, comparer)); + + await AssertEqual( + ints.OrderByDescending(i => i % 2 == 0 ? i : -1, comparer), + source.OrderByDescending(i => i % 2 == 0 ? i : -1, comparer)); + + await AssertEqual( + ints.OrderByDescending(i => i % 2 == 0 ? i : -1, comparer), + source.OrderByDescending(async (i, ct) => i % 2 == 0 ? i : -1, comparer)); + + Assert.Equal(copy, ints); + } + } + } + } + + [Fact] + public async Task VariousValues_MatchesEnumerable_String() + { + Random rand = new(42); + foreach (int length in new[] { 0, 1, 2, 3, 4, 100, 1024 }) + { + foreach (IComparer comparer in new IComparer[] { null, Comparer.Default, StringComparer.Ordinal, StringComparer.OrdinalIgnoreCase }) + { + string[] strings = new string[length]; + FillRandom(rand, strings); + + string[] copy = strings.ToArray(); + + foreach (IAsyncEnumerable source in CreateSources(strings)) + { +#if NET + await AssertEqual( + strings.Order(comparer), + source.Order(comparer)); + + await AssertEqual( + strings.OrderDescending(comparer), + source.OrderDescending(comparer)); +#endif + + await AssertEqual( + strings.OrderBy(s => s.Length), + source.OrderBy(s => s.Length)); + + await AssertEqual( + strings.OrderBy(s => s.Length), + source.OrderBy(async (s, ct) => s.Length)); + + await AssertEqual( + strings.OrderByDescending(s => s.Length), + source.OrderByDescending(s => s.Length)); + + await AssertEqual( + strings.OrderByDescending(s => s.Length), + source.OrderByDescending(async (s, ct) => s.Length)); + + await AssertEqual( + strings.OrderBy(s => s.Length).ThenBy(s => s.Length > 0 ? s[0] : ' '), + source.OrderBy(s => s.Length).ThenBy(s => s.Length > 0 ? s[0] : ' ')); + + await AssertEqual( + strings.OrderBy(s => s.Length).ThenBy(s => s.Length > 0 ? s[0] : ' '), + source.OrderBy(async (s, ct) => s.Length).ThenBy(async (s, ct) => s.Length > 0 ? s[0] : ' ')); + + await AssertEqual( + strings.OrderByDescending(s => s.Length).ThenByDescending(s => s.Length > 0 ? s[0] : ' '), + source.OrderByDescending(s => s.Length).ThenByDescending(s => s.Length > 0 ? s[0] : ' ')); + + await AssertEqual( + strings.OrderByDescending(s => s.Length).ThenByDescending(s => s.Length > 0 ? s[0] : ' '), + source.OrderByDescending(async (s, ct) => s.Length).ThenByDescending(async (s, ct) => s.Length > 0 ? s[0] : ' ')); + + Assert.Equal(copy, strings); + } + } + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + + await Validate(source.OrderBy(i => i)); + await Validate(source.OrderBy(async (i, ct) => i)); + await Validate(source.OrderByDescending(i => i)); + await Validate(source.OrderByDescending(async (i, ct) => i)); + + static async Task Validate(IAsyncEnumerable source) + { + await Assert.ThrowsAsync(async () => + { + await ConsumeAsync(source.WithCancellation(new CancellationToken(true))); + }); + } + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + await Validate(source => source.OrderBy(i => i)); + await Validate(source => source.OrderBy(async (i, ct) => i)); + await Validate(source => source.OrderByDescending(i => i)); + await Validate(source => source.OrderByDescending(async (i, ct) => i)); + + async Task Validate(Func, IAsyncEnumerable> factory) + { + TrackingAsyncEnumerable source = CreateSource(Enumerable.Range(0, 100).ToArray()).Track(); + await ConsumeAsync(factory(source)); + Assert.Equal(101, source.MoveNextAsyncCount); + Assert.Equal(100, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/PrependTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/PrependTests.cs new file mode 100644 index 00000000000000..6316954592b2e3 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/PrependTests.cs @@ -0,0 +1,60 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class PrependTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.Prepend(null, 42)); + } + +#if NET + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + await AssertEqual( + values.Prepend(42), + source.Prepend(42)); + } + } +#endif + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + CancellationTokenSource cts = new(); + await Assert.ThrowsAsync(async () => + { + await foreach (int item in source.Prepend(42).WithCancellation(cts.Token)) + { + cts.Cancel(); + } + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source = CreateSource(2, 4, 8, 16).Track(); + await ConsumeAsync(source.Prepend(42)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/QueryComprehesionTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/QueryComprehesionTests.cs new file mode 100644 index 00000000000000..20f58ebfbdd6dc --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/QueryComprehesionTests.cs @@ -0,0 +1,241 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class QueryComprehensionTests : AsyncEnumerableTests + { + // Tests based on the C# specification section 12.20.3 Query expression translation + + /// 12.20.3.2 Query expressions with continuations + [Fact] + public async Task QueryExpressionsWithContinuations() + { + await AssertEqual( + from c in GetCities() + group c by c.State into g + select $"{g.Key}: {string.Join(", ", g.Select(c => c.Name))}", + + from c in GetCitiesAsync() + group c by c.State into g + select $"{g.Key}: {string.Join(", ", g.Select(c => c.Name))}"); + + await AssertEqual( + from g in (from c in GetCities() + group c by c.State) + select $"{g.Key}: {string.Join(", ", g.Select(c => c.Name))}", + + from g in (from c in GetCitiesAsync() + group c by c.State) + select $"{g.Key}: {string.Join(", ", g.Select(c => c.Name))}"); + } + + /// 12.20.3.3 Explicit range variable types + [Fact] + public async Task ExplicitRangeVariableTypes() + { + await AssertEqual( + from City c in GetCities() select c.Name, + from City c in GetCitiesAsync() select c.Name); + } + + // 12.20.3.5 From, let, where, join and orderby clauses + [Fact] + public async Task FromLetWhereJoinClauses() + { + await AssertEqual( + from c1 in GetCities() + from c2 in GetCities() + where c1.Name != c2.Name + select $"{c1.Name} => {c2.Name}", + + from c1 in GetCitiesAsync() + from c2 in GetCitiesAsync() + where c1.Name != c2.Name + select $"{c1.Name} => {c2.Name}"); + + await AssertEqual( + from c1 in GetCities() + orderby c1.Name + from c2 in GetCities() + orderby c2.Name descending + where c1.Name != c2.Name + select $"{c1.Name} => {c2.Name}", + + from c1 in GetCitiesAsync() + orderby c1.Name + from c2 in GetCities() + orderby c2.Name descending + where c1.Name != c2.Name + select $"{c1.Name} => {c2.Name}"); + + await AssertEqual( + from c1 in GetCities() + orderby c1.Name + from c2 in GetCities() + orderby c2.Name descending + where c1.Name != c2.Name + select $"{c1.Name} => {c2.Name}", + + from c1 in GetCitiesAsync() + orderby c1.Name + from c2 in GetCitiesAsync() + orderby c2.Name descending + where c1.Name != c2.Name + select $"{c1.Name} => {c2.Name}"); + + await AssertEqual( + from c1 in GetCities() + let c1Name = c1.Name + from c2 in GetCities() + let c2Name = c2.Name + where c1Name != c2Name + select $"{c1Name} => {c2Name}", + + from c1 in GetCitiesAsync() + let c1Name = c1.Name + from c2 in GetCitiesAsync() + let c2Name = c2.Name + where c1Name != c2Name + select $"{c1Name} => {c2Name}"); + + await AssertEqual( + from c1 in GetCities() + where c1.State == "MA" + select c1.Name, + + from c1 in GetCitiesAsync() + where c1.State == "MA" + select c1.Name); + + await AssertEqual( + from c1 in GetCities() + join c2 in GetCities() on c1.State equals c2.State + select c1.Name, + + from c1 in GetCitiesAsync() + join c2 in GetCitiesAsync() on c1.State equals c2.State + select c1.Name); + + await AssertEqual( + from c1 in GetCities() + join c2 in GetCities() on c1.State equals c2.State into g + from c in g + select c.Name, + + from c1 in GetCitiesAsync() + join c2 in GetCitiesAsync() on c1.State equals c2.State into g + from c in g + select c.Name); + } + + // 12.20.3.5 From, let, where, join and orderby clauses + [Fact] + public async Task OrderByClauses() + { + await AssertEqual( + from c in GetCities() + orderby c.State, c.Name descending + select c.Name, + + from c in GetCitiesAsync() + orderby c.State, c.Name descending + select c.Name); + } + + // 12.20.3.6 Select clauses + [Fact] + public async Task SelectClauses() + { + await AssertEqual( + from c in GetCities() + select c.Name, + + from c in GetCitiesAsync() + select c.Name); + } + + // 12.20.3.7 Group clauses + [Fact] + public async Task GroupClauses() + { + await AssertEqual( + from c in GetCities() + group c.Name by c.State, + + from c in GetCitiesAsync() + group c.Name by c.State); + } + + private record City(string Name, string State); + + private static async IAsyncEnumerable GetCitiesAsync() + { + foreach (City city in GetCities()) + { + await Task.Yield(); + yield return city; + } + } + + private static IEnumerable GetCities() + { + yield return new("Birmingham", "AL"); + yield return new("Anchorage", "AK"); + yield return new("Phoenix", "AZ"); + yield return new("Tucson", "AZ"); + yield return new("Mesa", "AZ"); + yield return new("Little Rock", "AR"); + yield return new("Los Angeles", "CA"); + yield return new("San Diego", "CA"); + yield return new("San Jose", "CA"); + yield return new("San Francisco", "CA"); + yield return new("Fresno", "CA"); + yield return new("Sacramento", "CA"); + yield return new("Long Beach", "CA"); + yield return new("Oakland", "CA"); + yield return new("Bakersfield", "CA"); + yield return new("Denver", "CO"); + yield return new("Washington", "DC"); + yield return new("Jacksonville", "FL"); + yield return new("Miami", "FL"); + yield return new("Tampa", "FL"); + yield return new("Atlanta", "GA"); + yield return new("Chicago", "IL"); + yield return new("Indianapolis", "IN"); + yield return new("Louisville", "KY"); + yield return new("New Orleans", "LA"); + yield return new("Baltimore", "MD"); + yield return new("Boston", "MA"); + yield return new("Detroit", "MI"); + yield return new("Minneapolis", "MN"); + yield return new("Kansas City", "MO"); + yield return new("Las Vegas", "NV"); + yield return new("Albuquerque", "NM"); + yield return new("New York", "NY"); + yield return new("Charlotte", "NC"); + yield return new("Raleigh", "NC"); + yield return new("Columbus", "OH"); + yield return new("Oklahoma City", "OK"); + yield return new("Tulsa", "OK"); + yield return new("Portland", "OR"); + yield return new("Philadelphia", "PA"); + yield return new("Memphis", "TN"); + yield return new("Nashville", "TN"); + yield return new("Houston", "TX"); + yield return new("San Antonio", "TX"); + yield return new("Dallas", "TX"); + yield return new("Austin", "TX"); + yield return new("Fort Worth", "TX"); + yield return new("El Paso", "TX"); + yield return new("Arlington", "TX"); + yield return new("Virginia Beach", "VA"); + yield return new("Seattle", "WA"); + yield return new("Milwaukee", "WI"); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/RangeTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/RangeTests.cs new file mode 100644 index 00000000000000..75413ef13fc7c4 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/RangeTests.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class RangeTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("count", () => AsyncEnumerable.Range(-1, -1)); + AssertExtensions.Throws("count", () => AsyncEnumerable.Range(-1, -1)); + AssertExtensions.Throws("count", () => AsyncEnumerable.Range(int.MaxValue - 1, 3)); + } + + [Fact] + public async Task VariousValues_MatchesEnumerable() + { + foreach (int start in new[] { int.MinValue, -1, 0, 1, 1_000_000 }) + { + foreach (int count in new[] { 0, 1, 3, 10 }) + { + await AssertEqual( + Enumerable.Range(start, count), + AsyncEnumerable.Range(start, count)); + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/RepeatTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/RepeatTests.cs new file mode 100644 index 00000000000000..f9bea5532d5b7f --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/RepeatTests.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class RepeatTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("count", () => AsyncEnumerable.Repeat("a", -1)); + } + + [Fact] + public async Task VariousValues_MatchesEnumerable() + { + foreach (int count in new[] { 0, 1, 10 }) + { + await AssertEqual( + Enumerable.Repeat(42, count), + AsyncEnumerable.Repeat(42, count)); + + await AssertEqual( + Enumerable.Repeat("test", count), + AsyncEnumerable.Repeat("test", count)); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/ReverseTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/ReverseTests.cs new file mode 100644 index 00000000000000..2c4678e3d3a95c --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/ReverseTests.cs @@ -0,0 +1,54 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class ReverseTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.Reverse(null)); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + await AssertEqual( + values.Reverse(), + source.Reverse()); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + await Assert.ThrowsAsync(async () => + { + await ConsumeAsync(source.Reverse().WithCancellation(new CancellationToken(true))); + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source = CreateSource(2, 4, 8, 16).Track(); + await ConsumeAsync(source.Reverse()); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/RightJoinTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/RightJoinTests.cs new file mode 100644 index 00000000000000..1c5c02eb9831e1 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/RightJoinTests.cs @@ -0,0 +1,166 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class RightJoinTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("outer", () => AsyncEnumerable.RightJoin((IAsyncEnumerable)null, AsyncEnumerable.Empty(), outer => outer, inner => inner, (outer, inner) => outer + inner)); + AssertExtensions.Throws("inner", () => AsyncEnumerable.RightJoin(AsyncEnumerable.Empty(), (IAsyncEnumerable)null, outer => outer, inner => inner, (outer, inner) => outer + inner)); + AssertExtensions.Throws("outerKeySelector", () => AsyncEnumerable.RightJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (Func)null, inner => inner, (outer, inner) => outer + inner)); + AssertExtensions.Throws("innerKeySelector", () => AsyncEnumerable.RightJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), outer => outer, (Func)null, (outer, inner) => outer + inner)); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.RightJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), outer => outer, inner => inner, (Func)null)); + + AssertExtensions.Throws("outer", () => AsyncEnumerable.RightJoin((IAsyncEnumerable)null, AsyncEnumerable.Empty(), async (outer, ct) => outer, async (inner, ct) => inner, async (outer, inner, ct) => outer + inner)); + AssertExtensions.Throws("inner", () => AsyncEnumerable.RightJoin(AsyncEnumerable.Empty(), (IAsyncEnumerable)null, async (outer, ct) => outer, async (inner, ct) => inner, async (outer, inner, ct) => outer + inner)); + AssertExtensions.Throws("outerKeySelector", () => AsyncEnumerable.RightJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (Func>)null, async (inner, ct) => inner, async (outer, inner, ct) => outer + inner)); + AssertExtensions.Throws("innerKeySelector", () => AsyncEnumerable.RightJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), async (outer, ct) => outer, (Func>)null, async (outer, inner, ct) => outer + inner)); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.RightJoin(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), async (outer, ct) => outer, async (inner, ct) => inner, (Func>)null)); + } + +#if NET + [Fact] + public async Task VariousValues_MatchesEnumerable_String() + { + Random rand = new(42); + foreach (int length in new[] { 0, 1, 2, 1000 }) + { + string[] values = new string[length]; + FillRandom(rand, values); + + foreach (IAsyncEnumerable source in CreateSources(values)) + { + await AssertEqual( + values.RightJoin(values, s => s.Length > 0 ? s[0] : ' ', s => s.Length > 1 ? s[1] : ' ', (outer, inner) => outer + inner), + source.RightJoin(source, s => s.Length > 0 ? s[0] : ' ', s => s.Length > 1 ? s[1] : ' ', (outer, inner) => outer + inner)); + + await AssertEqual( + values.RightJoin(values, s => s.Length > 0 ? s[0] : ' ', s => s.Length > 1 ? s[1] : ' ', (outer, inner) => outer + inner), + source.RightJoin(source, async (s, ct) => s.Length > 0 ? s[0] : ' ', async (s, ct) => s.Length > 1 ? s[1] : ' ', async (outer, inner, ct) => outer + inner)); + } + } + } +#endif + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.RightJoin(source, outer => + { + cts.Cancel(); + return outer; + }, + inner => + { + return inner; + }, + (outer, inner) => + { + return outer + inner; + }).WithCancellation(cts.Token)); + }); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.RightJoin(source, + async (outer, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return outer; + }, + async (inner, ct) => + { + return inner; + }, + async (outer, inner, ct) => + { + return outer + inner; + }).WithCancellation(cts.Token)); + }); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.RightJoin(source, + async (outer, ct) => + { + return outer; + }, + async (inner, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return inner; + }, + async (outer, inner, ct) => + { + return outer + inner; + }).WithCancellation(cts.Token)); + }); + + await Assert.ThrowsAsync(async () => + { + CancellationTokenSource cts = new(); + await ConsumeAsync(source.RightJoin(source, + async (outer, ct) => + { + return outer; + }, + async (inner, ct) => + { + return inner; + }, + async (outer, inner, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return outer + inner; + }).WithCancellation(cts.Token)); + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable outer, inner; + + outer = CreateSource(2, 4, 8, 16).Track(); + inner = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(outer.RightJoin(inner, outer => outer, inner => inner, (outer, inner) => outer + inner)); + Assert.Equal(5, outer.MoveNextAsyncCount); + Assert.Equal(4, outer.CurrentCount); + Assert.Equal(1, outer.DisposeAsyncCount); + Assert.Equal(5, inner.MoveNextAsyncCount); + Assert.Equal(4, inner.CurrentCount); + Assert.Equal(1, inner.DisposeAsyncCount); + + outer = CreateSource(2, 4, 8, 16).Track(); + inner = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(outer.RightJoin(inner, async (outer, ct) => outer, async (inner, ct) => inner, async (outer, inner, ct) => outer + inner)); + Assert.Equal(5, outer.MoveNextAsyncCount); + Assert.Equal(4, outer.CurrentCount); + Assert.Equal(1, outer.DisposeAsyncCount); + Assert.Equal(5, inner.MoveNextAsyncCount); + Assert.Equal(4, inner.CurrentCount); + Assert.Equal(1, inner.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/SelectManyTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/SelectManyTests.cs new file mode 100644 index 00000000000000..ffdb4c8593026a --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/SelectManyTests.cs @@ -0,0 +1,201 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class SelectManyTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.SelectMany(null, i => Enumerable.Empty())); + AssertExtensions.Throws("source", () => AsyncEnumerable.SelectMany(null, async (i, ct) => Enumerable.Empty())); + AssertExtensions.Throws("source", () => AsyncEnumerable.SelectMany(null, i => AsyncEnumerable.Empty())); + + AssertExtensions.Throws("source", () => AsyncEnumerable.SelectMany(null, (i, index) => Enumerable.Empty())); + AssertExtensions.Throws("source", () => AsyncEnumerable.SelectMany(null, async (i, index, ct) => Enumerable.Empty())); + AssertExtensions.Throws("source", () => AsyncEnumerable.SelectMany(null, (i, index) => AsyncEnumerable.Empty())); + + AssertExtensions.Throws("source", () => AsyncEnumerable.SelectMany(null, i => Enumerable.Empty(), (i, s) => s)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SelectMany(null, i => AsyncEnumerable.Empty(), (i, s) => s)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SelectMany(null, async (i, ct) => Enumerable.Empty(), async (i, s, ct) => s)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SelectMany(null, i => AsyncEnumerable.Empty(), async (i, s, ct) => s)); + + AssertExtensions.Throws("source", () => AsyncEnumerable.SelectMany(null, (i, index) => Enumerable.Empty(), (i, s) => s)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SelectMany(null, async (i, index, ct) => Enumerable.Empty(), async (i, s, ct) => s)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SelectMany(null, (i, index) => AsyncEnumerable.Empty(), async (i, s, ct) => s)); + + AssertExtensions.Throws("selector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), (Func>)null)); + AssertExtensions.Throws("selector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), (Func>>)null)); + AssertExtensions.Throws("selector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), (Func>)null)); + + AssertExtensions.Throws("selector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), (Func>)null)); + AssertExtensions.Throws("selector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), (Func>>)null)); + AssertExtensions.Throws("selector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), (Func>)null)); + + AssertExtensions.Throws("collectionSelector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), (Func>)null, (i, s) => s)); + AssertExtensions.Throws("collectionSelector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), (Func>)null, (i, s) => s)); + AssertExtensions.Throws("collectionSelector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), (Func>>)null, async (i, s, ct) => s)); + AssertExtensions.Throws("collectionSelector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), (Func>)null, async (i, s, ct) => s)); + + AssertExtensions.Throws("collectionSelector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), (Func>)null, (i, s) => s)); + AssertExtensions.Throws("collectionSelector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), (Func>>)null, async (i, s, ct) => s)); + AssertExtensions.Throws("collectionSelector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), (Func>)null, async (i, s, ct) => s)); + + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), i => Enumerable.Empty(), (Func)null)); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), i => AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), async (i, ct) => Enumerable.Empty(), (Func>)null)); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), i => AsyncEnumerable.Empty(), (Func>)null)); + + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), (i, index) => Enumerable.Empty(), (Func)null)); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), async (i, index, ct) => Enumerable.Empty(), (Func>)null)); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.SelectMany(AsyncEnumerable.Empty(), (i, index) => AsyncEnumerable.Empty(), (Func>)null)); + } + + [Fact] + public async Task VariousValues_MatchesEnumerable() + { + Random rand = new(42); + foreach (int collectionSize in new[] { 0, 1, 10, 50 }) + { + foreach (int chunkSize in new[] { 1, 2, 3, 5, 60 }) + { + int[] ints = new int[collectionSize]; + FillRandom(rand, ints); + + foreach (IAsyncEnumerable source in CreateSources(ints)) + { + Func>[] selectors = + [ + i => Array.Empty(), + i => [i], + i => [i, i * 2], + ]; + + foreach (Func> selector in selectors) + { + await AssertEqual( + ints.SelectMany(i => selector(i)), + source.SelectMany(i => selector(i))); + + await AssertEqual( + ints.SelectMany(i => selector(i)), + source.SelectMany(async (i, ct) => selector(i))); + + await AssertEqual( + ints.SelectMany(i => selector(i)), + source.SelectMany(i => selector(i).ToAsyncEnumerable())); + + await AssertEqual( + ints.SelectMany((i, index) => selector(i * index)), + source.SelectMany((i, index) => selector(i * index))); + + await AssertEqual( + ints.SelectMany((i, index) => selector(i * index)), + source.SelectMany(async (i, index, ct) => selector(i * index))); + + await AssertEqual( + ints.SelectMany((i, index) => selector(i * index)), + source.SelectMany((i, index) => selector(i * index).ToAsyncEnumerable())); + + await AssertEqual( + ints.SelectMany(i => selector(i), (i, result) => result), + source.SelectMany(i => selector(i), (i, result) => result)); + + await AssertEqual( + ints.SelectMany(i => selector(i), (i, result) => result), + source.SelectMany(async (i, ct) => selector(i), async (i, result, ct) => result)); + + await AssertEqual( + ints.SelectMany(i => selector(i), (i, result) => result), + source.SelectMany(i => selector(i).ToAsyncEnumerable(), (i, result) => result)); + + await AssertEqual( + ints.SelectMany(i => selector(i), (i, result) => result), + source.SelectMany(i => selector(i).ToAsyncEnumerable(), async (i, result, ct) => result)); + + await AssertEqual( + ints.SelectMany((i, index) => selector(i * index), (i, result) => result), + source.SelectMany((i, index) => selector(i * index), (i, result) => result)); + + await AssertEqual( + ints.SelectMany((i, index) => selector(i * index), (i, result) => result), + source.SelectMany(async (i, index, ct) => selector(i * index), async (i, result, ct) => result)); + + await AssertEqual( + ints.SelectMany((i, index) => selector(i * index), (i, result) => result), + source.SelectMany((i, index) => selector(i * index).ToAsyncEnumerable(), async (i, result, ct) => result)); + } + } + } + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + + await Validate(source.SelectMany(i => new[] { i })); + await Validate(source.SelectMany(async (i, ct) => new[] { i })); + await Validate(source.SelectMany(i => new[] { i }.ToAsyncEnumerable())); + + await Validate(source.SelectMany((i, index) => new[] { i })); + await Validate(source.SelectMany(async (i, index, ct) => new[] { i })); + await Validate(source.SelectMany((i, index) => new[] { i }.ToAsyncEnumerable())); + + await Validate(source.SelectMany(i => new[] { i }, (i, result) => result)); + await Validate(source.SelectMany(async (i, ct) => new[] { i }, async (i, result, ct) => result)); + await Validate(source.SelectMany(i => new[] { i }.ToAsyncEnumerable(), (i, result) => result)); + await Validate(source.SelectMany(i => new[] { i }.ToAsyncEnumerable(), async (i, result, ct) => result)); + + await Validate(source.SelectMany((i, index) => new[] { i })); + await Validate(source.SelectMany(async (i, index, ct) => new[] { i }, async (i, result, ct) => result)); + await Validate(source.SelectMany((i, index) => new[] { i }.ToAsyncEnumerable(), async (i, result, ct) => result)); + + static async Task Validate(IAsyncEnumerable source) + { + CancellationTokenSource cts = new(); + await Assert.ThrowsAsync(async () => + { + await foreach (int item in source.WithCancellation(cts.Token)) + { + cts.Cancel(); + } + }); + } + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + await Validate(source => source.SelectMany(i => new[] { i, i + 1, i * 2 })); + await Validate(source => source.SelectMany(async (i, ct) => new[] { i, i + 1, i * 2 })); + await Validate(source => source.SelectMany(i => new[] { i, i + 1, i * 2 }.ToAsyncEnumerable())); + await Validate(source => source.SelectMany((i, index) => new[] { i, i + 1, i * 2 })); + await Validate(source => source.SelectMany(async (i, index, ct) => new[] { i, i + 1, i * 2 })); + await Validate(source => source.SelectMany((i, index) => new[] { i, i + 1, i * 2 }.ToAsyncEnumerable())); + await Validate(source => source.SelectMany(i => new[] { i, i + 1, i * 2 }, (i, result) => result)); + await Validate(source => source.SelectMany(async (i, ct) => new[] { i, i + 1, i * 2 }, async (i, result, ct) => result)); + await Validate(source => source.SelectMany(i => new[] { i, i + 1, i * 2 }.ToAsyncEnumerable(), (i, result) => result)); + await Validate(source => source.SelectMany(i => new[] { i, i + 1, i * 2 }.ToAsyncEnumerable(), async (i, result, ct) => result)); + await Validate(source => source.SelectMany((i, index) => new[] { i, i + 1, i * 2 }, (i, result) => result)); + await Validate(source => source.SelectMany(async (i, index, ct) => new[] { i, i + 1, i * 2 }, async (i, result, ct) => result)); + await Validate(source => source.SelectMany((i, index) => new[] { i, i + 1, i * 2 }.ToAsyncEnumerable(), async (i, result, ct) => result)); + + async static Task Validate(Func, IAsyncEnumerable> factory) + { + TrackingAsyncEnumerable source = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(factory(source)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/SelectTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/SelectTests.cs new file mode 100644 index 00000000000000..9e82c46e73cf98 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/SelectTests.cs @@ -0,0 +1,96 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class SelectTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.Select(null, i => i.ToString())); + AssertExtensions.Throws("source", () => AsyncEnumerable.Select(null, (i, index) => i.ToString())); + AssertExtensions.Throws("source", () => AsyncEnumerable.Select(null, async (i, ct) => i.ToString())); + AssertExtensions.Throws("source", () => AsyncEnumerable.Select(null, async (i, index, ct) => i.ToString())); + + AssertExtensions.Throws("selector", () => AsyncEnumerable.Select(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("selector", () => AsyncEnumerable.Select(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("selector", () => AsyncEnumerable.Select(AsyncEnumerable.Empty(), (Func>)null)); + AssertExtensions.Throws("selector", () => AsyncEnumerable.Select(AsyncEnumerable.Empty(), (Func>)null)); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 42 })] + [InlineData(new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 })] + [InlineData(new int[] { -1, 1, -2, 2, -10, 10 })] + [InlineData(new int[] { int.MinValue, int.MaxValue })] + public async Task VariousValues_MatchesEnumerable(int[] ints) + { + foreach (IAsyncEnumerable source in CreateSources(ints)) + { + await AssertEqual( + ints.Select(i => i.ToString()), + source.Select(i => i.ToString())); + + await AssertEqual( + ints.Select((i, index) => (i + index).ToString()), + source.Select((i, index) => (i + index).ToString())); + + await AssertEqual( + ints.Select(i => i.ToString()), + source.Select(async (int i, CancellationToken ct) => i.ToString())); + + await AssertEqual( + ints.Select((i, index) => (i + index).ToString()), + source.Select(async (i, index, ct) => (i + index).ToString())); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + + await Validate(source.Select(i => i)); + await Validate(source.Select((i, index) => i)); + await Validate(source.Select(async (int i, CancellationToken index) => i)); + await Validate(source.Select(async (i, index, ct) => i)); + + static async Task Validate(IAsyncEnumerable source) + { + CancellationTokenSource cts = new(); + await Assert.ThrowsAsync(async () => + { + await foreach (int item in source.WithCancellation(cts.Token)) + { + cts.Cancel(); + } + }); + } + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + await Validate(source => source.Select(i => i)); + await Validate(source => source.Select((i, index) => i)); + await Validate(source => source.Select(async (int i, CancellationToken cancellationToken) => i)); + await Validate(source => source.Select(async (i, index, ct) => i)); + + async Task Validate(Func, IAsyncEnumerable> factory) + { + TrackingAsyncEnumerable source = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(factory(source)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/SequenceEqualAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/SequenceEqualAsyncTests.cs new file mode 100644 index 00000000000000..53b0cedb2ffafb --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/SequenceEqualAsyncTests.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class SequenceEqualAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("first", () => AsyncEnumerable.SequenceEqualAsync(null, AsyncEnumerable.Empty())); + AssertExtensions.Throws("second", () => AsyncEnumerable.SequenceEqualAsync(AsyncEnumerable.Empty(), null)); + } + + [Fact] + public async Task VariousValues_MatchesEnumerable() + { + Random rand = new(42); + foreach (int length in new[] { 0, 1, 10 }) + { + int[] values = new int[length]; + FillRandom(rand, values); + + foreach (IAsyncEnumerable source in CreateSources(values)) + { + foreach (IEqualityComparer comparer in new[] { EqualityComparer.Default, null, OddEvenComparer }) + { + Assert.Equal( + values.SequenceEqual(values, comparer), + await source.SequenceEqualAsync(source, comparer)); + + Assert.Equal( + values.SequenceEqual(values.Concat([1]), comparer), + await source.SequenceEqualAsync(source.Concat(new[] { 1 }.ToAsyncEnumerable()), comparer)); + + Assert.Equal( + values.SequenceEqual(new[] { 42 }.Concat(values), comparer), + await source.SequenceEqualAsync(new[] { 1 }.ToAsyncEnumerable().Concat(source), comparer)); + } + } + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(1, 3, 5); + await Assert.ThrowsAsync(async () => await source.SequenceEqualAsync(source, null, new CancellationToken(true))); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable first, second; + + first = CreateSource(1, 3, 5).Track(); + second = CreateSource(1, 3, 5).Track(); + Assert.True(await first.SequenceEqualAsync(second)); + Assert.Equal(4, first.MoveNextAsyncCount); + Assert.Equal(3, first.CurrentCount); + Assert.Equal(1, first.DisposeAsyncCount); + Assert.Equal(4, second.MoveNextAsyncCount); + Assert.Equal(3, second.CurrentCount); + Assert.Equal(1, second.DisposeAsyncCount); + + first = CreateSource(1).Track(); + second = CreateSource(1, 3, 5).Track(); + Assert.False(await first.SequenceEqualAsync(second)); + Assert.Equal(2, first.MoveNextAsyncCount); + Assert.Equal(1, first.CurrentCount); + Assert.Equal(1, first.DisposeAsyncCount); + Assert.Equal(2, second.MoveNextAsyncCount); + Assert.Equal(1, second.CurrentCount); + Assert.Equal(1, second.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/SingleAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/SingleAsyncTests.cs new file mode 100644 index 00000000000000..142361d3189467 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/SingleAsyncTests.cs @@ -0,0 +1,137 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class SingleAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.SingleAsync(null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SingleAsync(null, i => i % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SingleAsync(null, async (i, ct) => i % 2 == 0)); + + AssertExtensions.Throws("predicate", () => AsyncEnumerable.SingleAsync(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.SingleAsync(AsyncEnumerable.Empty(), (Func>)null)); + } + + [Fact] + public async Task EmptyInputs_Throws() + { + ValueTask first; + + first = AsyncEnumerable.Empty().SingleAsync(); + await Assert.ThrowsAsync(async () => await first); + + first = AsyncEnumerable.Empty().SingleAsync(i => i % 2 == 0); + await Assert.ThrowsAsync(async () => await first); + + first = AsyncEnumerable.Empty().SingleAsync(async (i, ct) => i % 2 == 0); + await Assert.ThrowsAsync(async () => await first); + + first = new int[] { 1, 3, 5 }.ToAsyncEnumerable().SingleAsync(i => i % 2 == 0); + await Assert.ThrowsAsync(async () => await first); + + first = new int[] { 1, 3, 5 }.ToAsyncEnumerable().SingleAsync(async (i, ct) => i % 2 == 0); + await Assert.ThrowsAsync(async () => await first); + } + + [Fact] + public async Task DoubleInputs_Throws() + { + ValueTask single; + + single = new int[] { 1, 2 }.ToAsyncEnumerable().SingleAsync(); + await Assert.ThrowsAsync(async () => await single); + + single = new int[] { 1, 2, 1, 2 }.ToAsyncEnumerable().SingleAsync(i => i % 2 == 0); + await Assert.ThrowsAsync(async () => await single); + + single = new int[] { 1, 2, 1, 2 }.ToAsyncEnumerable().SingleAsync(async (i, ct) => i % 2 == 0); + await Assert.ThrowsAsync(async () => await single); + } + + [Theory] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { 1, 3, 5, 7 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + if (values.Length == 1) + { + Assert.Equal( + values.Single(), + await source.SingleAsync()); + } + + Func predicate = i => i == values.Last(); + + Assert.Equal( + values.Single(predicate), + await source.SingleAsync(predicate)); + + Assert.Equal( + values.Single(predicate), + await source.SingleAsync(async (i, ct) => predicate(i))); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + CancellationTokenSource cts; + + await Assert.ThrowsAsync(async () => await source.SingleAsync(new CancellationToken(true))); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.SingleAsync(x => + { + cts.Cancel(); + return x > 32; + }, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.SingleAsync(async (x, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x > 32; + }, cts.Token)); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + + source = CreateSource(2).Track(); + await source.SingleAsync(); + Assert.Equal(2, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + await source.SingleAsync(i => i == 8); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).Track(); + await source.SingleAsync(async (i, ct) => i == 2); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/SingleOrDefaultAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/SingleOrDefaultAsyncTests.cs new file mode 100644 index 00000000000000..fb977e13012e77 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/SingleOrDefaultAsyncTests.cs @@ -0,0 +1,136 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class SingleOrDefaultAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.SingleOrDefaultAsync(null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SingleOrDefaultAsync(null, i => i % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SingleOrDefaultAsync(null, async (i, ct) => i % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SingleOrDefaultAsync(null, 42)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SingleOrDefaultAsync(null, i => i % 2 == 0, 42)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SingleOrDefaultAsync(null, async (i, ct) => i % 2 == 0, 42)); + + AssertExtensions.Throws("predicate", () => AsyncEnumerable.SingleOrDefaultAsync(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.SingleOrDefaultAsync(AsyncEnumerable.Empty(), (Func>)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.SingleOrDefaultAsync(AsyncEnumerable.Empty(), (Func)null, 42)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.SingleOrDefaultAsync(AsyncEnumerable.Empty(), (Func>)null, 42)); + } + + [Fact] + public async Task EmptyInputs_DefaultValueReturned() + { + Assert.Equal(0, await AsyncEnumerable.Empty().SingleOrDefaultAsync()); + Assert.Equal(42, await AsyncEnumerable.Empty().SingleOrDefaultAsync(42)); + Assert.Equal(0, await AsyncEnumerable.Empty().SingleOrDefaultAsync(i => i % 2 == 0)); + Assert.Equal(42, await AsyncEnumerable.Empty().SingleOrDefaultAsync(i => i % 2 == 0, 42)); + Assert.Equal(0, await AsyncEnumerable.Empty().SingleOrDefaultAsync(async (i, ct) => i % 2 == 0)); + Assert.Equal(42, await AsyncEnumerable.Empty().SingleOrDefaultAsync(async (i, ct) => i % 2 == 0, 42)); + + IAsyncEnumerable source = new int[] { 1, 3, 5 }.ToAsyncEnumerable(); + Assert.Equal(0, await source.SingleOrDefaultAsync(i => i % 2 == 0)); + Assert.Equal(42, await source.SingleOrDefaultAsync(i => i % 2 == 0, 42)); + Assert.Equal(0, await source.SingleOrDefaultAsync(async (i, ct) => i % 2 == 0)); + Assert.Equal(42, await source.SingleOrDefaultAsync(async (i, ct) => i % 2 == 0, 42)); + } + + [Fact] + public async Task DoubleInputs_Throws() + { + await Validate(new int[] { 1, 2 }.ToAsyncEnumerable().SingleOrDefaultAsync()); + await Validate(new int[] { 1, 2, 1, 2 }.ToAsyncEnumerable().SingleOrDefaultAsync(i => i % 2 == 0)); + await Validate(new int[] { 1, 2, 1, 2 }.ToAsyncEnumerable().SingleOrDefaultAsync(async (i, ct) => i % 2 == 0)); + + await Validate(new int[] { 1, 2 }.ToAsyncEnumerable().SingleOrDefaultAsync(42)); + await Validate(new int[] { 1, 2, 1, 2 }.ToAsyncEnumerable().SingleOrDefaultAsync(i => i % 2 == 0, 42)); + await Validate(new int[] { 1, 2, 1, 2 }.ToAsyncEnumerable().SingleOrDefaultAsync(async (i, ct) => i % 2 == 0, 42)); + + static async Task Validate(ValueTask task) + { + await Assert.ThrowsAsync(async () => await task); + } + } + + [Theory] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { 1, 3, 5, 7 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + if (values.Length == 1) + { + Assert.Equal( + values.SingleOrDefault(), + await source.SingleOrDefaultAsync()); + } + + Func predicate = i => i == values.Last(); + + Assert.Equal( + values.SingleOrDefault(predicate), + await source.SingleOrDefaultAsync(predicate)); + + Assert.Equal( + values.SingleOrDefault(predicate), + await source.SingleOrDefaultAsync(async (i, ct) => predicate(i))); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + CancellationTokenSource cts; + + await Assert.ThrowsAsync(async () => await source.SingleOrDefaultAsync(new CancellationToken(true))); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.SingleOrDefaultAsync(x => + { + cts.Cancel(); + return x > 32; + }, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.SingleOrDefaultAsync(async (x, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x > 32; + }, cts.Token)); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + await Validate(s => s.SingleOrDefaultAsync(), count: 1); + await Validate(s => s.SingleOrDefaultAsync(42), count: 1); + await Validate(s => s.SingleOrDefaultAsync(i => i == 1)); + await Validate(s => s.SingleOrDefaultAsync(i => i == 1, 42)); + await Validate(s => s.SingleOrDefaultAsync(async (i, ct) => i == 1)); + await Validate(s => s.SingleOrDefaultAsync(async (i, ct) => i == 1, 42)); + + static async Task Validate(Func, ValueTask> func, int count = 4) + { + TrackingAsyncEnumerable source = CreateSource(Enumerable.Range(1, count).ToArray()).Track(); + await func(source); + Assert.Equal(count + 1, source.MoveNextAsyncCount); + Assert.Equal(count, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/SkipLastTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/SkipLastTests.cs new file mode 100644 index 00000000000000..1da94e367d3d0e --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/SkipLastTests.cs @@ -0,0 +1,65 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class SkipLastTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.SkipLast((IAsyncEnumerable)null, 42)); + } + +#if NET + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 42 })] + [InlineData(new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 })] + [InlineData(new int[] { -1, 1, -2, 2, -10, 10 })] + [InlineData(new int[] { int.MinValue, int.MaxValue })] + public async Task VariousValues_MatchesEnumerable(int[] ints) + { + foreach (IAsyncEnumerable source in CreateSources(ints)) + { + foreach (int count in new[] { -1, 0, 1, 2, 10 }) + { + await AssertEqual( + ints.SkipLast(count), + source.SkipLast(count)); + } + } + } +#endif + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + await Assert.ThrowsAsync(async () => await ConsumeAsync(source.SkipLast(1).WithCancellation(new CancellationToken(true)))); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + + source = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(source.SkipLast(0)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(source.SkipLast(3)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/SkipTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/SkipTests.cs new file mode 100644 index 00000000000000..e3f1b3d121b0d8 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/SkipTests.cs @@ -0,0 +1,63 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class SkipTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.Skip((IAsyncEnumerable)null, 42)); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 42 })] + [InlineData(new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 })] + [InlineData(new int[] { -1, 1, -2, 2, -10, 10 })] + [InlineData(new int[] { int.MinValue, int.MaxValue })] + public async Task VariousValues_MatchesEnumerable(int[] ints) + { + foreach (IAsyncEnumerable source in CreateSources(ints)) + { + foreach (int count in new[] { -1, 0, 1, 2, 10 }) + { + await AssertEqual( + ints.Skip(count), + source.Skip(count)); + } + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + await Assert.ThrowsAsync(async () => await ConsumeAsync(source.Skip(1).WithCancellation(new CancellationToken(true)))); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + + source = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(source.Skip(0)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(source.Skip(3)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/SkipWhileTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/SkipWhileTests.cs new file mode 100644 index 00000000000000..1da93416d22a3e --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/SkipWhileTests.cs @@ -0,0 +1,134 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class SkipWhileTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.SkipWhile((IAsyncEnumerable)null, i => i % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SkipWhile((IAsyncEnumerable)null, (i, index) => i % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SkipWhile((IAsyncEnumerable)null, async (i, ct) => i % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SkipWhile((IAsyncEnumerable)null, async (i, index, ct) => i % 2 == 0)); + + AssertExtensions.Throws("predicate", () => AsyncEnumerable.SkipWhile(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.SkipWhile(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.SkipWhile(AsyncEnumerable.Empty(), (Func>)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.SkipWhile(AsyncEnumerable.Empty(), (Func>)null)); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 42 })] + [InlineData(new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 })] + [InlineData(new int[] { -1, 1, -2, 2, -10, 10 })] + [InlineData(new int[] { int.MinValue, int.MaxValue })] + public async Task VariousValues_MatchesEnumerable(int[] ints) + { + foreach (IAsyncEnumerable source in CreateSources(ints)) + { + foreach (bool b in TrueFalseBools) + { + await AssertEqual( + ints.SkipWhile(i => b), + source.SkipWhile(i => b)); + + await AssertEqual( + ints.SkipWhile(i => b), + source.SkipWhile(async (i, ct) => b)); + + await AssertEqual( + ints.SkipWhile((i, index) => b), + source.SkipWhile((i, index) => b)); + + await AssertEqual( + ints.SkipWhile((i, index) => b), + source.SkipWhile(async (i, index, ct) => b)); + } + + await AssertEqual( + ints.SkipWhile((i, index) => index < 2), + source.SkipWhile((i, index) => index < 2)); + + await AssertEqual( + ints.SkipWhile((i, index) => index < 2), + source.SkipWhile(async (i, index, ct) => index < 2)); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + await Assert.ThrowsAsync(async () => await ConsumeAsync(source.SkipWhile(i => true).WithCancellation(new CancellationToken(true)))); + + CancellationTokenSource cts; + + cts = new CancellationTokenSource(); + await Assert.ThrowsAsync(async () => await ConsumeAsync(source.SkipWhile(i => + { + cts.Cancel(); + return true; + }).WithCancellation(cts.Token))); + + cts = new CancellationTokenSource(); + await Assert.ThrowsAsync(async () => await ConsumeAsync(source.SkipWhile(async (i, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return true; + }).WithCancellation(cts.Token))); + + cts = new CancellationTokenSource(); + await Assert.ThrowsAsync(async () => await ConsumeAsync(source.SkipWhile((i, index) => + { + cts.Cancel(); + return true; + }).WithCancellation(cts.Token))); + + cts = new CancellationTokenSource(); + await Assert.ThrowsAsync(async () => await ConsumeAsync(source.SkipWhile(async (i, index, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return true; + }).WithCancellation(cts.Token))); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + + foreach (bool useAsync in TrueFalseBools) + { + foreach (bool useIndex in TrueFalseBools) + { + foreach (bool trueFalse in TrueFalseBools) + { + source = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync((useAsync, useIndex) switch + { + (false, false) => source.SkipWhile(i => trueFalse), + (false, true) => source.SkipWhile((i, index) => trueFalse), + (true, false) => source.SkipWhile(async (i, ct) => trueFalse), + (true, true) => source.SkipWhile(async (i, index, ct) => trueFalse), + }); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/SumAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/SumAsyncTests.cs new file mode 100644 index 00000000000000..965361c580fcd3 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/SumAsyncTests.cs @@ -0,0 +1,125 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Globalization; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class SumAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.SumAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SumAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SumAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SumAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SumAsync((IAsyncEnumerable)null)); + + AssertExtensions.Throws("source", () => AsyncEnumerable.SumAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SumAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SumAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SumAsync((IAsyncEnumerable)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.SumAsync((IAsyncEnumerable)null)); + } + [Fact] + public async Task EmptyInputs_NonNullable_Throws() + { + Assert.Equal(0, await AsyncEnumerable.SumAsync(AsyncEnumerable.Empty())); + Assert.Equal(0, await AsyncEnumerable.SumAsync(AsyncEnumerable.Empty())); + Assert.Equal(0, await AsyncEnumerable.SumAsync(AsyncEnumerable.Empty())); + Assert.Equal(0, await AsyncEnumerable.SumAsync(AsyncEnumerable.Empty())); + Assert.Equal(0, await AsyncEnumerable.SumAsync(AsyncEnumerable.Empty())); + + Assert.Equal(0, await AsyncEnumerable.SumAsync(AsyncEnumerable.Empty())); + Assert.Equal(0, await AsyncEnumerable.SumAsync(AsyncEnumerable.Empty())); + Assert.Equal(0, await AsyncEnumerable.SumAsync(AsyncEnumerable.Empty())); + Assert.Equal(0, await AsyncEnumerable.SumAsync(AsyncEnumerable.Empty())); + Assert.Equal(0, await AsyncEnumerable.SumAsync(AsyncEnumerable.Empty())); + } + + [Theory] + [InlineData(new int[] { 0 })] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { -int.MaxValue, int.MaxValue })] + [InlineData(new int[] { -1, -2, -3 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + Assert.Equal(values.Select(i => (int)i).Sum(), await source.Select(i => (int)i).SumAsync()); + Assert.Equal(values.Select(i => (long)i).Sum(), await source.Select(i => (long)i).SumAsync()); + Assert.Equal(values.Select(i => (float)i).Sum(), await source.Select(i => (float)i).SumAsync()); + Assert.Equal(values.Select(i => (double)i).Sum(), await source.Select(i => (double)i).SumAsync()); + Assert.Equal(values.Select(i => (decimal)i).Sum(), await source.Select(i => (decimal)i).SumAsync()); + + Assert.Equal(values.Select(i => (int?)i).Sum(), await source.Select(i => (int?)i).SumAsync()); + Assert.Equal(values.Select(i => (long?)i).Sum(), await source.Select(i => (long?)i).SumAsync()); + Assert.Equal(values.Select(i => (float?)i).Sum(), await source.Select(i => (float?)i).SumAsync()); + Assert.Equal(values.Select(i => (double?)i).Sum(), await source.Select(i => (double?)i).SumAsync()); + Assert.Equal(values.Select(i => (decimal?)i).Sum(), await source.Select(i => (decimal?)i).SumAsync()); + + Assert.Equal(values.Select(i => (int?)i).Sum(), await source.SelectMany(i => [i, null]).SumAsync()); + Assert.Equal(values.Select(i => (long?)i).Sum(), await source.SelectMany(i => [i, null]).SumAsync()); + Assert.Equal(values.Select(i => (float?)i).Sum(), await source.SelectMany(i => [i, null]).SumAsync()); + Assert.Equal(values.Select(i => (double?)i).Sum(), await source.SelectMany(i => [i, null]).SumAsync()); + Assert.Equal(values.Select(i => (decimal?)i).Sum(), await source.SelectMany(i => [i, null]).SumAsync()); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (int)i).SumAsync(new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (long)i).SumAsync(new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (float)i).SumAsync(new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (double)i).SumAsync(new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (decimal)i).SumAsync(new CancellationToken(true))); + + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (int?)i).SumAsync(new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (long?)i).SumAsync(new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (float?)i).SumAsync(new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (double?)i).SumAsync(new CancellationToken(true))); + await Assert.ThrowsAsync(async () => await CreateSource(2, 4).Select(i => (decimal?)i).SumAsync(new CancellationToken(true))); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + await Validate(source => source.Select(i => (int)i).SumAsync()); + await Validate(source => source.Select(i => (long)i).SumAsync()); + await Validate(source => source.Select(i => (float)i).SumAsync()); + await Validate(source => source.Select(i => (double)i).SumAsync()); + await Validate(source => source.Select(i => (decimal)i).SumAsync()); + + await Validate(source => source.Select(i => (int?)i).SumAsync()); + await Validate(source => source.Select(i => (long?)i).SumAsync()); + await Validate(source => source.Select(i => (float?)i).SumAsync()); + await Validate(source => source.Select(i => (double?)i).SumAsync()); + await Validate(source => source.Select(i => (decimal?)i).SumAsync()); + + static async Task Validate(Func, ValueTask> factory) + { + TrackingAsyncEnumerable source; + + source = CreateSource(2, 4, 8, 16).Track(); + await factory(source); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(2, 4, 8, 16).AppendException(new FormatException()).Track(); + await Assert.ThrowsAsync(async () => await factory(source)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/System.Linq.AsyncEnumerable.Tests.csproj b/src/libraries/System.Linq.AsyncEnumerable/tests/System.Linq.AsyncEnumerable.Tests.csproj new file mode 100644 index 00000000000000..1a819d87b173b9 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/System.Linq.AsyncEnumerable.Tests.csproj @@ -0,0 +1,82 @@ + + + $(NetCoreAppCurrent);$(NetFrameworkMinimum) + $(NoWarn);CS1998 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/TakeLastTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/TakeLastTests.cs new file mode 100644 index 00000000000000..0b723c7268befc --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/TakeLastTests.cs @@ -0,0 +1,65 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class TakeLastTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.TakeLast((IAsyncEnumerable)null, 42)); + } + +#if NET + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 42 })] + [InlineData(new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 })] + [InlineData(new int[] { -1, 1, -2, 2, -10, 10 })] + [InlineData(new int[] { int.MinValue, int.MaxValue })] + public async Task VariousValues_MatchesEnumerable(int[] ints) + { + foreach (IAsyncEnumerable source in CreateSources(ints)) + { + foreach (int count in new[] { -1, 0, 1, 2, 10 }) + { + await AssertEqual( + ints.TakeLast(count), + source.TakeLast(count)); + } + } + } +#endif + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + await Assert.ThrowsAsync(async () => await ConsumeAsync(source.TakeLast(1).WithCancellation(new CancellationToken(true)))); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + + source = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(source.TakeLast(1)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(source.TakeLast(3)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/TakeTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/TakeTests.cs new file mode 100644 index 00000000000000..72582de10a2205 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/TakeTests.cs @@ -0,0 +1,102 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class TakeTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.Take((IAsyncEnumerable)null, 42)); + AssertExtensions.Throws("source", () => AsyncEnumerable.Take((IAsyncEnumerable)null, new Range(new(0), new(42)))); + } + + [Fact] + public void TakeNothing_ReturnsEmpty() + { + Assert.Same(AsyncEnumerable.Empty(), AsyncEnumerable.Take(new int[] { 1, 2, 3 }.ToAsyncEnumerable(), 0)); + Assert.Same(AsyncEnumerable.Empty(), AsyncEnumerable.Take(new int[] { 1, 2, 3 }.ToAsyncEnumerable(), -1)); + Assert.Same(AsyncEnumerable.Empty(), AsyncEnumerable.Take(new int[] { 1, 2, 3 }.ToAsyncEnumerable(), new Range(new(0), new(0)))); + Assert.Same(AsyncEnumerable.Empty(), AsyncEnumerable.Take(new int[] { 1, 2, 3 }.ToAsyncEnumerable(), new Range(new(0, fromEnd: true), new(0, fromEnd: true)))); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 42 })] + [InlineData(new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 })] + [InlineData(new int[] { -1, 1, -2, 2, -10, 10 })] + [InlineData(new int[] { int.MinValue, int.MaxValue })] + public async Task VariousValues_MatchesEnumerable(int[] ints) + { + foreach (IAsyncEnumerable source in CreateSources(ints)) + { + foreach (int count in new[] { 0, 1, 2, 10 }) + { + await AssertEqual( + ints.Take(count), + source.Take(count)); + +#if NET + await AssertEqual( + ints.Take(..count), + source.Take(..count)); + + await AssertEqual( + ints.Take(count..), + source.Take(count..)); + + await AssertEqual( + ints.Take(..^count), + source.Take(..^count)); + + await AssertEqual( + ints.Take(^count..), + source.Take(^count..)); + + await AssertEqual( + ints.Take(3..(3+count)), + source.Take(3..(3+count))); +#endif + } + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + await Assert.ThrowsAsync(async () => await ConsumeAsync(source.Take(1).WithCancellation(new CancellationToken(true)))); + await Assert.ThrowsAsync(async () => await ConsumeAsync(source.Take(new Range(new(0), new(3))).WithCancellation(new CancellationToken(true)))); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + + source = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(source.Take(1)); + Assert.Equal(1, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(source.Take(3)); + Assert.Equal(3, source.MoveNextAsyncCount); + Assert.Equal(3, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + + source = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(source.Take(new Range(new(0), new(1)))); + Assert.Equal(1, source.MoveNextAsyncCount); + Assert.Equal(1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/TakeWhileTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/TakeWhileTests.cs new file mode 100644 index 00000000000000..fbcb2a97f470e6 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/TakeWhileTests.cs @@ -0,0 +1,134 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class TakeWhileTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.TakeWhile((IAsyncEnumerable)null, i => i % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.TakeWhile((IAsyncEnumerable)null, (i, index) => i % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.TakeWhile((IAsyncEnumerable)null, async (i, ct) => i % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.TakeWhile((IAsyncEnumerable)null, async (i, index, ct) => i % 2 == 0)); + + AssertExtensions.Throws("predicate", () => AsyncEnumerable.TakeWhile(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.TakeWhile(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.TakeWhile(AsyncEnumerable.Empty(), (Func>)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.TakeWhile(AsyncEnumerable.Empty(), (Func>)null)); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 42 })] + [InlineData(new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 })] + [InlineData(new int[] { -1, 1, -2, 2, -10, 10 })] + [InlineData(new int[] { int.MinValue, int.MaxValue })] + public async Task VariousValues_MatchesEnumerable(int[] ints) + { + foreach (IAsyncEnumerable source in CreateSources(ints)) + { + foreach (bool b in TrueFalseBools) + { + await AssertEqual( + ints.TakeWhile(i => b), + source.TakeWhile(i => b)); + + await AssertEqual( + ints.TakeWhile(i => b), + source.TakeWhile(async (i, ct) => b)); + + await AssertEqual( + ints.TakeWhile((i, index) => b), + source.TakeWhile((i, index) => b)); + + await AssertEqual( + ints.TakeWhile((i, index) => b), + source.TakeWhile(async (i, index, ct) => b)); + } + + await AssertEqual( + ints.TakeWhile((i, index) => index < 2), + source.TakeWhile((i, index) => index < 2)); + + await AssertEqual( + ints.TakeWhile((i, index) => index < 2), + source.TakeWhile(async (i, index, ct) => index < 2)); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + await Assert.ThrowsAsync(async () => await ConsumeAsync(source.TakeWhile(i => true).WithCancellation(new CancellationToken(true)))); + + CancellationTokenSource cts; + + cts = new CancellationTokenSource(); + await Assert.ThrowsAsync(async () => await ConsumeAsync(source.TakeWhile(i => + { + cts.Cancel(); + return true; + }).WithCancellation(cts.Token))); + + cts = new CancellationTokenSource(); + await Assert.ThrowsAsync(async () => await ConsumeAsync(source.TakeWhile(async (i, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return true; + }).WithCancellation(cts.Token))); + + cts = new CancellationTokenSource(); + await Assert.ThrowsAsync(async () => await ConsumeAsync(source.TakeWhile((i, index) => + { + cts.Cancel(); + return true; + }).WithCancellation(cts.Token))); + + cts = new CancellationTokenSource(); + await Assert.ThrowsAsync(async () => await ConsumeAsync(source.TakeWhile(async (i, index, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return true; + }).WithCancellation(cts.Token))); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + + foreach (bool useAsync in TrueFalseBools) + { + foreach (bool useIndex in TrueFalseBools) + { + foreach (bool trueFalse in TrueFalseBools) + { + source = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync((useAsync, useIndex) switch + { + (false, false) => source.TakeWhile(i => trueFalse), + (false, true) => source.TakeWhile((i, index) => trueFalse), + (true, false) => source.TakeWhile(async (i, ct) => trueFalse), + (true, true) => source.TakeWhile(async (i, index, ct) => trueFalse), + }); + Assert.Equal(trueFalse ? 5 : 1, source.MoveNextAsyncCount); + Assert.Equal(trueFalse ? 4 : 1, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/ToArrayAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/ToArrayAsyncTests.cs new file mode 100644 index 00000000000000..970ddf3b877f84 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/ToArrayAsyncTests.cs @@ -0,0 +1,57 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class ToArrayAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.ToArrayAsync(null)); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 1, 1, 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8, 6, -1, 5, 14 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + Assert.Equal( + values.ToArray(), + await source.ToArrayAsync()); + + Assert.Equal( + values.Select(i => i.ToString()).ToArray(), + await source.Select(i => i.ToString()).ToArrayAsync()); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + await Assert.ThrowsAsync(async () => await source.ToArrayAsync(new CancellationToken(true))); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source = CreateSource(2, 4, 8, 16).Track(); + await source.ToArrayAsync(); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/ToAsyncEnumerableTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/ToAsyncEnumerableTests.cs new file mode 100644 index 00000000000000..da851de0037d12 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/ToAsyncEnumerableTests.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class ToAsyncEnumerableTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.ToAsyncEnumerable(null)); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 1, 1, 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8, 6, -1, 5, 14 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + await AssertEqual(values, (await source.ToArrayAsync()).ToAsyncEnumerable()); + await AssertEqual(values, (await source.ToListAsync()).ToAsyncEnumerable()); + await AssertEqual(values, new ReadOnlyCollection(await source.ToListAsync()).ToAsyncEnumerable()); + await AssertEqual(values, new Queue(await source.ToListAsync()).ToAsyncEnumerable()); + } + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source = CreateSource(2, 4, 8, 16).Track(); + await source.ToArrayAsync(); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/ToDictionaryAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/ToDictionaryAsyncTests.cs new file mode 100644 index 00000000000000..ac5ddbf4beec5c --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/ToDictionaryAsyncTests.cs @@ -0,0 +1,236 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class ToDictionaryAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.ToDictionaryAsync((IAsyncEnumerable>)null)); + AssertExtensions.Throws("source", () => AsyncEnumerable.ToDictionaryAsync((IAsyncEnumerable)null, s => s.Length)); + AssertExtensions.Throws("source", () => AsyncEnumerable.ToDictionaryAsync((IAsyncEnumerable)null, async (s, ct) => s.Length)); + AssertExtensions.Throws("source", () => AsyncEnumerable.ToDictionaryAsync((IAsyncEnumerable)null, s => s.Length, s => s)); + AssertExtensions.Throws("source", () => AsyncEnumerable.ToDictionaryAsync((IAsyncEnumerable)null, async (s, ct) => s.Length, async (s, ct) => s)); + + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.ToDictionaryAsync(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.ToDictionaryAsync(AsyncEnumerable.Empty(), (Func>)null)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.ToDictionaryAsync(AsyncEnumerable.Empty(), (Func)null, s => s)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.ToDictionaryAsync(AsyncEnumerable.Empty(), (Func>)null, async (s, ct) => s)); + + AssertExtensions.Throws("elementSelector", () => AsyncEnumerable.ToDictionaryAsync(AsyncEnumerable.Empty(), s => s.Length, (Func)null)); + AssertExtensions.Throws("elementSelector", () => AsyncEnumerable.ToDictionaryAsync(AsyncEnumerable.Empty(), async (s, ct) => s.Length, (Func>)null)); + } + + [Fact] + public async Task Duplicates_Throws() + { + ValueTask> result; + + result = CreateSource("a", "b", "a").ToDictionaryAsync(s => s); + await Assert.ThrowsAsync(async () => await result); + + result = CreateSource("a", "b", "c").ToDictionaryAsync(s => "a"); + await Assert.ThrowsAsync(async () => await result); + + result = CreateSource("a", "b", "c").ToDictionaryAsync(async (s, ct) => "a"); + await Assert.ThrowsAsync(async () => await result); + + result = CreateSource("a", "b", "c").ToDictionaryAsync(s => "a", s => s); + await Assert.ThrowsAsync(async () => await result); + + result = CreateSource("a", "b", "c").ToDictionaryAsync(async (s, ct) => "a", async (s, ct) => s); + await Assert.ThrowsAsync(async () => await result); + } + + [Fact] + public async Task VariousValues_MatchesEnumerable() + { + Random rand = new(42); + foreach (int length in new[] { 0, 1, 2, 100 }) + { + string[] values = new string[length]; + FillRandom(rand, values); + for (int i = 0; i < length; i++) + { + values[i] = values[i] + (char)('A' + i); + } + + foreach (IAsyncEnumerable source in CreateSources(values)) + { + foreach (IEqualityComparer comparer in new IEqualityComparer[] { null, EqualityComparer.Default, StringComparer.OrdinalIgnoreCase }) + { +#if NET + Assert.Equal( + values.Select(s => KeyValuePair.Create(s, s)).ToDictionary(comparer), + await source.Select(s => KeyValuePair.Create(s, s)).ToDictionaryAsync(comparer)); + + Assert.Equal( + values.Select(s => (s, s)).ToDictionary(comparer), + await source.Select(s => (s, s)).ToDictionaryAsync(comparer)); +#endif + + Assert.Equal( + values.ToDictionary(s => s + s, comparer), + await source.ToDictionaryAsync(s => s + s, comparer)); + + Assert.Equal( + values.ToDictionary(s => s + s, comparer), + await source.ToDictionaryAsync(async (s, ct) => + { + await Task.Yield(); + return s + s; + }, comparer)); + + Assert.Equal( + values.ToDictionary(s => s + s, s => s.Length > 0 ? s.Substring(1) : "", comparer), + await source.ToDictionaryAsync(s => s + s, s => s.Length > 0 ? s.Substring(1) : "", comparer)); + + Assert.Equal( + values.ToDictionary(s => s + s, s => s.Length > 0 ? s.Substring(1) : "", comparer), + await source.ToDictionaryAsync(async (s, ct) => + { + await Task.Yield(); + return s + s; + }, async (s, ct) => + { + await Task.Yield(); + return s.Length > 0 ? s.Substring(1) : ""; + }, comparer)); + } + } + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + await Assert.ThrowsAsync(async () => await source.ToDictionaryAsync(i => i, null, new CancellationToken(true))); + + CancellationTokenSource cts; + + cts = new(); + await Assert.ThrowsAsync(async () => await source.ToDictionaryAsync(i => + { + cts.Cancel(); + return i; + }, null, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.ToDictionaryAsync(i => + { + cts.Cancel(); + return i; + }, i => i, null, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.ToDictionaryAsync(i => i, i => + { + cts.Cancel(); + return i; + }, null, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.ToDictionaryAsync(async (i, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return i; + }, null, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.ToDictionaryAsync(async (i, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return i; + }, async (i, ct) => i, null, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.ToDictionaryAsync(async (i, ct) => + { + Assert.Equal(cts.Token, ct); + return i; + }, async (i, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return i; + }, null, cts.Token)); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + int keySelectorCount, elementSelectorCount; + + keySelectorCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + await source.ToDictionaryAsync(i => + { + keySelectorCount++; + return i; + }); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, keySelectorCount); + + keySelectorCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + await source.ToDictionaryAsync(async (i, ct) => + { + keySelectorCount++; + return i; + }); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, keySelectorCount); + + keySelectorCount = elementSelectorCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + await source.ToDictionaryAsync(i => + { + keySelectorCount++; + return i; + }, i => + { + elementSelectorCount++; + return i; + }); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, keySelectorCount); + + keySelectorCount = elementSelectorCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + await source.ToDictionaryAsync(async (i, ct) => + { + keySelectorCount++; + return i; + }, async (i, ct) => + { + elementSelectorCount++; + return i; + }); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, keySelectorCount); + Assert.Equal(4, elementSelectorCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/ToHashSetAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/ToHashSetAsyncTests.cs new file mode 100644 index 00000000000000..0284492433b04c --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/ToHashSetAsyncTests.cs @@ -0,0 +1,65 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class ToHashSetAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.ToHashSetAsync(null)); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 1, 1, 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8, 6, -1, 5, 14 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + Assert.Equal( + new HashSet(values), + await source.ToHashSetAsync()); + + Assert.Equal( + new HashSet(values, OddEvenComparer).OrderBy(s => s), + (await source.ToHashSetAsync(OddEvenComparer)).OrderBy(s => s)); + + Assert.Equal( + new HashSet(values.Select(i => i.ToString())), + await source.Select(i => i.ToString()).ToHashSetAsync()); + + Assert.Equal( + new HashSet(values.Select(i => i.ToString()), LengthComparer).OrderBy(s => s), + (await source.Select(i => i.ToString()).ToHashSetAsync(LengthComparer)).OrderBy(s => s)); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + await Assert.ThrowsAsync(async () => await source.ToHashSetAsync(null, new CancellationToken(true))); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source = CreateSource(2, 4, 8, 16).Track(); + await source.ToHashSetAsync(); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/ToListAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/ToListAsyncTests.cs new file mode 100644 index 00000000000000..e0b7b8cf57477b --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/ToListAsyncTests.cs @@ -0,0 +1,57 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class ToListAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.ToListAsync(null)); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 1 })] + [InlineData(new int[] { 1, 1, 1 })] + [InlineData(new int[] { 2, 4, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8, 6, -1, 5, 14 })] + public async Task VariousValues_MatchesEnumerable(int[] values) + { + foreach (IAsyncEnumerable source in CreateSources(values)) + { + Assert.Equal( + values.ToList(), + await source.ToListAsync()); + + Assert.Equal( + values.Select(i => i.ToString()).ToList(), + await source.Select(i => i.ToString()).ToListAsync()); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + await Assert.ThrowsAsync(async () => await source.ToListAsync(new CancellationToken(true))); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source = CreateSource(2, 4, 8, 16).Track(); + await source.ToListAsync(); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/ToLookupAsyncTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/ToLookupAsyncTests.cs new file mode 100644 index 00000000000000..c66a7c228efcf7 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/ToLookupAsyncTests.cs @@ -0,0 +1,227 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class ToLookupAsyncTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.ToLookupAsync((IAsyncEnumerable)null, s => s.Length)); + AssertExtensions.Throws("source", () => AsyncEnumerable.ToLookupAsync((IAsyncEnumerable)null, async (s, ct) => s.Length)); + AssertExtensions.Throws("source", () => AsyncEnumerable.ToLookupAsync((IAsyncEnumerable)null, s => s.Length, s => s)); + AssertExtensions.Throws("source", () => AsyncEnumerable.ToLookupAsync((IAsyncEnumerable)null, async (s, ct) => s.Length, async (s, ct) => s)); + + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.ToLookupAsync(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.ToLookupAsync(AsyncEnumerable.Empty(), (Func>)null)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.ToLookupAsync(AsyncEnumerable.Empty(), (Func)null, s => s)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.ToLookupAsync(AsyncEnumerable.Empty(), (Func>)null, async (s, ct) => s)); + + AssertExtensions.Throws("elementSelector", () => AsyncEnumerable.ToLookupAsync(AsyncEnumerable.Empty(), s => s.Length, (Func)null)); + AssertExtensions.Throws("elementSelector", () => AsyncEnumerable.ToLookupAsync(AsyncEnumerable.Empty(), async (s, ct) => s.Length, (Func>)null)); + } + + [Fact] + public async Task VariousValues_MatchesEnumerable() + { + Random rand = new(42); + foreach (int length in new[] { 0, 1, 2, 100 }) + { + string[] values = new string[length]; + FillRandom(rand, values); + + foreach (IAsyncEnumerable source in CreateSources(values)) + { + foreach (IEqualityComparer comparer in new[] { null, EqualityComparer.Default, OddEvenComparer }) + { + AssertEqual( + values.ToLookup(s => s.Length, comparer), + await source.ToLookupAsync(s => s.Length, comparer)); + + AssertEqual( + values.ToLookup(s => s.Length, comparer), + await source.ToLookupAsync(async (s, ct) => + { + await Task.Yield(); + return s.Length; + }, comparer)); + + AssertEqual( + values.ToLookup(s => s.Length, s => s.Length > 0 ? s.Substring(1) : "", comparer), + await source.ToLookupAsync(s => s.Length, s => s.Length > 0 ? s.Substring(1) : "", comparer)); + + AssertEqual( + values.ToLookup(s => s.Length, s => s.Length > 0 ? s.Substring(1) : "", comparer), + await source.ToLookupAsync(async (s, ct) => + { + await Task.Yield(); + return s.Length; + }, async (s, ct) => + { + await Task.Yield(); + return s.Length > 0 ? s.Substring(1) : ""; + }, comparer)); + + static void AssertEqual( + ILookup expected, + ILookup actual) + { + Assert.Equal(expected.Count, actual.Count); + Assert.Equal(expected.SelectMany(kvp => kvp), actual.SelectMany(kvp => kvp)); + + foreach (IGrouping g in expected) + { + Assert.True(actual.Contains(g.Key)); + Assert.Equal(g, actual[g.Key]); + } + + foreach (IGrouping g in actual) + { + Assert.True(expected.Contains(g.Key)); + Assert.Equal(g, expected[g.Key]); + } + + foreach (IGrouping g in (IEnumerable)actual) + { + Assert.True(expected.Contains(g.Key)); + Assert.Equal(g, expected[g.Key]); + } + } + } + } + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(2, 4, 8, 16); + await Assert.ThrowsAsync(async () => await source.ToLookupAsync(i => i, null, new CancellationToken(true))); + + CancellationTokenSource cts; + + cts = new(); + await Assert.ThrowsAsync(async () => await source.ToLookupAsync(i => + { + cts.Cancel(); + return i; + }, null, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.ToLookupAsync(i => + { + cts.Cancel(); + return i; + }, i => i, null, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.ToLookupAsync(i => i, i => + { + cts.Cancel(); + return i; + }, null, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.ToLookupAsync(async (i, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return i; + }, null, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.ToLookupAsync(async (i, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return i; + }, async (i, ct) => i, null, cts.Token)); + + cts = new(); + await Assert.ThrowsAsync(async () => await source.ToLookupAsync(async (i, ct) => + { + Assert.Equal(cts.Token, ct); + return i; + }, async (i, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return i; + }, null, cts.Token)); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable source; + int keySelectorCount, elementSelectorCount; + + keySelectorCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + await source.ToLookupAsync(i => + { + keySelectorCount++; + return i; + }); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, keySelectorCount); + + keySelectorCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + await source.ToLookupAsync(async (i, ct) => + { + keySelectorCount++; + return i; + }); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, keySelectorCount); + + keySelectorCount = elementSelectorCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + await source.ToLookupAsync(i => + { + keySelectorCount++; + return i; + }, i => + { + elementSelectorCount++; + return i; + }); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, keySelectorCount); + + keySelectorCount = elementSelectorCount = 0; + source = CreateSource(2, 4, 8, 16).Track(); + await source.ToLookupAsync(async (i, ct) => + { + keySelectorCount++; + return i; + }, async (i, ct) => + { + elementSelectorCount++; + return i; + }); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + Assert.Equal(4, keySelectorCount); + Assert.Equal(4, elementSelectorCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/UnionByTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/UnionByTests.cs new file mode 100644 index 00000000000000..e363dda433c1e6 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/UnionByTests.cs @@ -0,0 +1,128 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class UnionByTests : AsyncEnumerableTests + { +#if NET + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("first", () => AsyncEnumerable.UnionBy((IAsyncEnumerable)null, AsyncEnumerable.Empty(), x => x.ToString())); + AssertExtensions.Throws("second", () => AsyncEnumerable.UnionBy(AsyncEnumerable.Empty(), null, x => x.Length)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.UnionBy(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (Func)null)); + + AssertExtensions.Throws("first", () => AsyncEnumerable.UnionBy((IAsyncEnumerable)null, AsyncEnumerable.Empty(), async (x, ct) => x.ToString())); + AssertExtensions.Throws("second", () => AsyncEnumerable.UnionBy(AsyncEnumerable.Empty(), null, async (x, ct) => x.Length)); + AssertExtensions.Throws("keySelector", () => AsyncEnumerable.UnionBy(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (Func>)null)); + } + + [Theory] + [InlineData(new int[0], new int[0])] + [InlineData(new int[0], new int[] { 42 })] + [InlineData(new int[] { 42, 43 }, new int[0])] + [InlineData(new int[] { 1 }, new int[] { 2, 3 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 3, 5 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 2, 4, 8 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 2, 5, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 }, new int[] { int.MinValue, int.MaxValue })] + public async Task VariousValues_MatchesEnumerable(int[] first, int[] second) + { + foreach (IAsyncEnumerable firstSource in CreateSources(first)) + { + foreach (IAsyncEnumerable secondSource in CreateSources(second)) + { + await AssertEqual( + first.UnionBy(second, x => x / 2), + firstSource.UnionBy(secondSource, x => x / 2)); + + await AssertEqual( + first.UnionBy(second, x => x / 2, OddEvenComparer), + firstSource.UnionBy(secondSource, x => x / 2, OddEvenComparer)); + + await AssertEqual( + first.UnionBy(second, x => x / 2), + firstSource.UnionBy(secondSource, async (x, ct) => x / 2)); + + await AssertEqual( + first.UnionBy(second, x => x / 2, OddEvenComparer), + firstSource.UnionBy(secondSource, async (x, ct) => x / 2, OddEvenComparer)); + } + } + } +#endif + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable first = CreateSource(2, 4, 8, 16); + IAsyncEnumerable second = CreateSource(1, 3, 5); + CancellationTokenSource cts; + + cts = new(); + await Assert.ThrowsAsync(async () => + { + await foreach (int item in first.UnionBy(second, x => x).WithCancellation(cts.Token)) + { + cts.Cancel(); + } + }); + + cts = new(); + await Assert.ThrowsAsync(async () => + { + await ConsumeAsync(first.UnionBy(second, x => + { + cts.Cancel(); + return x; + }).WithCancellation(cts.Token)); + }); + + cts = new(); + await Assert.ThrowsAsync(async () => + { + await ConsumeAsync(first.UnionBy(second, async (x, ct) => + { + Assert.Equal(cts.Token, ct); + await Task.Yield(); + cts.Cancel(); + return x; + }).WithCancellation(cts.Token)); + }); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task InterfaceCalls_ExpectedCounts(bool useAsync) + { + TrackingAsyncEnumerable first = CreateSource(2, 4, 8, 16, 32, 64).Track(); + TrackingAsyncEnumerable second = CreateSource(1, 3, 5).Track(); + int funcCount = 0; + await ConsumeAsync(useAsync ? + first.UnionBy(second, async (x, ct) => + { + funcCount++; + return x; + }) : + first.UnionBy(second, x => + { + funcCount++; + return x; + })); + Assert.Equal(7, first.MoveNextAsyncCount); + Assert.Equal(6, first.CurrentCount); + Assert.Equal(1, first.DisposeAsyncCount); + Assert.Equal(4, second.MoveNextAsyncCount); + Assert.Equal(3, second.CurrentCount); + Assert.Equal(1, second.DisposeAsyncCount); + Assert.Equal(9, funcCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/UnionTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/UnionTests.cs new file mode 100644 index 00000000000000..90219b88e62a4d --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/UnionTests.cs @@ -0,0 +1,83 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class UnionTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("first", () => AsyncEnumerable.Union(null, AsyncEnumerable.Empty())); + AssertExtensions.Throws("second", () => AsyncEnumerable.Union(AsyncEnumerable.Empty(), null)); + } + + [Theory] + [InlineData(new int[0], new int[0])] + [InlineData(new int[0], new int[] { 42 })] + [InlineData(new int[] { 42, 43 }, new int[0])] + [InlineData(new int[] { 1 }, new int[] { 2, 3 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 3, 5 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 2, 4, 8 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 2, 5, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 }, new int[] { int.MinValue, int.MaxValue })] + public async Task VariousValues_MatchesEnumerable(int[] first, int[] second) + { + foreach (IAsyncEnumerable firstSource in CreateSources(first)) + { + foreach (IAsyncEnumerable secondSource in CreateSources(second)) + { + await AssertEqual( + first.Union(second), + firstSource.Union(secondSource)); + + await AssertEqual( + second.Union(first), + secondSource.Union(firstSource)); + + await AssertEqual( + first.Union(second, OddEvenComparer), + firstSource.Union(secondSource, OddEvenComparer)); + } + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable first = CreateSource(2, 4, 8, 16); + IAsyncEnumerable second = CreateSource(1, 3, 5); + CancellationTokenSource cts = new(); + await Assert.ThrowsAsync(async () => + { + await foreach (int item in first.Union(second).WithCancellation(cts.Token)) + { + cts.Cancel(); + } + }); + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable first, second; + + first = CreateSource(2, 4, 8, 16).Track(); + second = CreateSource(1, 3, 5).Track(); + await ConsumeAsync(first.Union(second)); + + Assert.Equal(5, first.MoveNextAsyncCount); + Assert.Equal(4, first.CurrentCount); + Assert.Equal(1, first.DisposeAsyncCount); + + Assert.Equal(4, second.MoveNextAsyncCount); + Assert.Equal(3, second.CurrentCount); + Assert.Equal(1, second.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/WhereTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/WhereTests.cs new file mode 100644 index 00000000000000..b64a149b253d64 --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/WhereTests.cs @@ -0,0 +1,96 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class WhereTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("source", () => AsyncEnumerable.Where(null, i => i % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.Where(null, (i, index) => i % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.Where(null, async (i, ct) => i % 2 == 0)); + AssertExtensions.Throws("source", () => AsyncEnumerable.Where(null, async (i, index, ct) => i % 2 == 0)); + + AssertExtensions.Throws("predicate", () => AsyncEnumerable.Where(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.Where(AsyncEnumerable.Empty(), (Func)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.Where(AsyncEnumerable.Empty(), (Func>)null)); + AssertExtensions.Throws("predicate", () => AsyncEnumerable.Where(AsyncEnumerable.Empty(), (Func>)null)); + } + + [Theory] + [InlineData(new int[0])] + [InlineData(new int[] { 42 })] + [InlineData(new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 })] + [InlineData(new int[] { -1, 1, -2, 2, -10, 10 })] + [InlineData(new int[] { int.MinValue, int.MaxValue })] + public async Task VariousValues_MatchesEnumerable(int[] ints) + { + foreach (IAsyncEnumerable source in CreateSources(ints)) + { + await AssertEqual( + ints.Where(i => i % 2 == 0), + source.Where(i => i % 2 == 0)); + + await AssertEqual( + ints.Where((i, index) => (i + index) % 2 == 0), + source.Where((i, index) => (i + index) % 2 == 0)); + + await AssertEqual( + ints.Where(i => i % 2 == 0), + source.Where(async (int i, CancellationToken ct) => i % 2 == 0)); + + await AssertEqual( + ints.Where((i, index) => (i + index) % 2 == 0), + source.Where(async (i, index, ct) => (i + index) % 2 == 0)); + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + IAsyncEnumerable source = CreateSource(1, 3, 5, 6, 7, 8); + + await Validate(source.Where(i => i % 2 == 0)); + await Validate(source.Where((i, index) => (i + index) % 2 == 0)); + await Validate(source.Where(async (int i, CancellationToken index) => i % 2 == 0)); + await Validate(source.Where(async (i, index, ct) => (i + index) % 2 == 0)); + + static async Task Validate(IAsyncEnumerable source) + { + CancellationTokenSource cts = new(); + await Assert.ThrowsAsync(async () => + { + await foreach (int item in source.WithCancellation(cts.Token)) + { + cts.Cancel(); + } + }); + } + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + await Validate(source => source.Where(i => i % 2 == 0)); + await Validate(source => source.Where((i, index) => (i + index) % 2 == 0)); + await Validate(source => source.Where(async (int i, CancellationToken cancellationToken) => i % 2 == 0)); + await Validate(source => source.Where(async (i, index, ct) => (i + index) % 2 == 0)); + + async Task Validate(Func, IAsyncEnumerable> factory) + { + TrackingAsyncEnumerable source = CreateSource(1, 2, 3, 4).Track(); + await ConsumeAsync(factory(source)); + Assert.Equal(5, source.MoveNextAsyncCount); + Assert.Equal(4, source.CurrentCount); + Assert.Equal(1, source.DisposeAsyncCount); + } + } + } +} diff --git a/src/libraries/System.Linq.AsyncEnumerable/tests/ZipTests.cs b/src/libraries/System.Linq.AsyncEnumerable/tests/ZipTests.cs new file mode 100644 index 00000000000000..b09730ad3d7e2d --- /dev/null +++ b/src/libraries/System.Linq.AsyncEnumerable/tests/ZipTests.cs @@ -0,0 +1,146 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.Linq.Tests +{ + public class ZipTests : AsyncEnumerableTests + { + [Fact] + public void InvalidInputs_Throws() + { + AssertExtensions.Throws("first", () => AsyncEnumerable.Zip((IAsyncEnumerable)null, AsyncEnumerable.Empty())); + AssertExtensions.Throws("second", () => AsyncEnumerable.Zip(AsyncEnumerable.Empty(), (IAsyncEnumerable)null)); + + AssertExtensions.Throws("first", () => AsyncEnumerable.Zip((IAsyncEnumerable)null, AsyncEnumerable.Empty(), (s, i) => (s, i))); + AssertExtensions.Throws("second", () => AsyncEnumerable.Zip(AsyncEnumerable.Empty(), (IAsyncEnumerable)null, (s, i) => (s, i))); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.Zip(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (Func)null)); + + AssertExtensions.Throws("first", () => AsyncEnumerable.Zip((IAsyncEnumerable)null, AsyncEnumerable.Empty(), async (s, i, ct) => (s, i))); + AssertExtensions.Throws("second", () => AsyncEnumerable.Zip(AsyncEnumerable.Empty(), (IAsyncEnumerable)null, async (s, i, ct) => (s, i))); + AssertExtensions.Throws("resultSelector", () => AsyncEnumerable.Zip(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (Func>)null)); + + AssertExtensions.Throws("first", () => AsyncEnumerable.Zip((IAsyncEnumerable)null, AsyncEnumerable.Empty(), AsyncEnumerable.Empty())); + AssertExtensions.Throws("second", () => AsyncEnumerable.Zip(AsyncEnumerable.Empty(), (IAsyncEnumerable)null, AsyncEnumerable.Empty())); + AssertExtensions.Throws("third", () => AsyncEnumerable.Zip(AsyncEnumerable.Empty(), AsyncEnumerable.Empty(), (IAsyncEnumerable)null)); + } + + [Theory] + [InlineData(new int[0], new int[0])] + [InlineData(new int[0], new int[] { 42 })] + [InlineData(new int[] { 42, 43 }, new int[0])] + [InlineData(new int[] { 1 }, new int[] { 2, 3 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 3, 5 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 2, 4, 8 })] + [InlineData(new int[] { 2, 4, 8 }, new int[] { 2, 5, 8 })] + [InlineData(new int[] { -1, 2, 5, 6, 7, 8 }, new int[] { int.MinValue, int.MaxValue })] + public async Task VariousValues_MatchesEnumerable(int[] first, int[] second) + { + foreach (IAsyncEnumerable firstSource in CreateSources(first)) + { + foreach (IAsyncEnumerable secondSource in CreateSources(second)) + { + await AssertEqual( + first.Zip(second, (f, s) => (f, s)), + firstSource.Zip(secondSource, (f, s) => (f, s))); + + await AssertEqual( + first.Zip(second, (f, s) => (f, s)), + firstSource.Zip(secondSource, async (f, s, ct) => (f, s))); + +#if NET + await AssertEqual( + first.Zip(second), + firstSource.Zip(secondSource)); + + await AssertEqual( + first.Zip(second, second), + firstSource.Zip(secondSource, secondSource)); + + await AssertEqual( + first.Zip(second, first), + firstSource.Zip(secondSource, firstSource)); +#endif + } + } + } + + [Fact] + public async Task Cancellation_Cancels() + { + await Validate((first, second) => first.Zip(second, (f, s) => (f, s))); + await Validate((first, second) => first.Zip(second, async (f, s, ct) => (f, s))); +#if NET + await Validate((first, second) => first.Zip(second)); +#endif + + static async Task Validate(Func, IAsyncEnumerable, IAsyncEnumerable<(int, int)>> factory) + { + IAsyncEnumerable first = CreateSource(2, 4, 8, 16); + IAsyncEnumerable second = CreateSource(1, 3, 5); + CancellationTokenSource cts = new(); + await Assert.ThrowsAsync(async () => + { + await foreach ((int, int) item in factory(first, second).WithCancellation(cts.Token)) + { + cts.Cancel(); + } + }); + } + } + + [Fact] + public async Task InterfaceCalls_ExpectedCounts() + { + TrackingAsyncEnumerable first, second, third; + + first = CreateSource(2, 4, 8, 16).Track(); + second = CreateSource(1, 3, 5).Track(); + await ConsumeAsync(first.Zip(second)); + Assert.Equal(4, first.MoveNextAsyncCount); + Assert.Equal(3, first.CurrentCount); + Assert.Equal(1, first.DisposeAsyncCount); + Assert.Equal(4, second.MoveNextAsyncCount); + Assert.Equal(3, second.CurrentCount); + Assert.Equal(1, second.DisposeAsyncCount); + + first = CreateSource(2, 4, 8, 16).Track(); + second = CreateSource(1, 3, 5).Track(); + await ConsumeAsync(first.Zip(second, (f, s) => (f, s))); + Assert.Equal(4, first.MoveNextAsyncCount); + Assert.Equal(3, first.CurrentCount); + Assert.Equal(1, first.DisposeAsyncCount); + Assert.Equal(4, second.MoveNextAsyncCount); + Assert.Equal(3, second.CurrentCount); + Assert.Equal(1, second.DisposeAsyncCount); + + first = CreateSource(1, 3, 5).Track(); + second = CreateSource(2, 4, 8, 16).Track(); + await ConsumeAsync(first.Zip(second, async (f, s, ct) => (f, s))); + Assert.Equal(4, first.MoveNextAsyncCount); + Assert.Equal(3, first.CurrentCount); + Assert.Equal(1, first.DisposeAsyncCount); + Assert.Equal(3, second.MoveNextAsyncCount); + Assert.Equal(3, second.CurrentCount); + Assert.Equal(1, second.DisposeAsyncCount); + + first = CreateSource(1, 3, 5).Track(); + second = CreateSource(2, 4, 8, 16).Track(); + third = CreateSource(42, 84).Track(); + await ConsumeAsync(first.Zip(second, third)); + Assert.Equal(3, first.MoveNextAsyncCount); + Assert.Equal(2, first.CurrentCount); + Assert.Equal(1, first.DisposeAsyncCount); + Assert.Equal(3, second.MoveNextAsyncCount); + Assert.Equal(2, second.CurrentCount); + Assert.Equal(1, second.DisposeAsyncCount); + Assert.Equal(3, third.MoveNextAsyncCount); + Assert.Equal(2, third.CurrentCount); + Assert.Equal(1, third.DisposeAsyncCount); + } + } +} diff --git a/src/libraries/System.Text.Json/tests/Common/JsonTestHelper.cs b/src/libraries/System.Text.Json/tests/Common/JsonTestHelper.cs index 8d96c4e79c432b..7fbdfef6a7a14c 100644 --- a/src/libraries/System.Text.Json/tests/Common/JsonTestHelper.cs +++ b/src/libraries/System.Text.Json/tests/Common/JsonTestHelper.cs @@ -247,16 +247,6 @@ public static IEnumerable CrossJoin( Func resultSelector) => first.CrossJoin(second, third).Select(tuple => resultSelector(tuple.First, tuple.Second, tuple.Third)); - public static async Task> ToListAsync(this IAsyncEnumerable source) - { - var list = new List(); - await foreach (T item in source) - { - list.Add(item); - } - return list; - } - private static readonly Regex s_stripWhitespace = new Regex(@"\s+", RegexOptions.Compiled); public static string StripWhitespace(this string value) diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/CollectionTests.cs b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/CollectionTests.cs index 79528a4a61749d..d54eebb4125752 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/CollectionTests.cs +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/Serialization/CollectionTests.cs @@ -7,6 +7,7 @@ using System.Collections.Immutable; using System.Collections.ObjectModel; using System.Collections.Specialized; +using System.Linq; using System.Text.Json.Nodes; using System.Text.Json.Serialization; using System.Text.Json.Serialization.Tests; diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/System.Text.Json.SourceGeneration.Tests.targets b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/System.Text.Json.SourceGeneration.Tests.targets index 7982d11db5ee04..17f03cf208cc08 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/System.Text.Json.SourceGeneration.Tests.targets +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.SourceGeneration.Tests/System.Text.Json.SourceGeneration.Tests.targets @@ -148,6 +148,7 @@ + diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.Tests/Serialization/Stream.DeserializeAsyncEnumerable.cs b/src/libraries/System.Text.Json/tests/System.Text.Json.Tests/Serialization/Stream.DeserializeAsyncEnumerable.cs index 23c7c5659f8481..fbc1619c6f9832 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.Tests/Serialization/Stream.DeserializeAsyncEnumerable.cs +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.Tests/Serialization/Stream.DeserializeAsyncEnumerable.cs @@ -415,16 +415,6 @@ private static JsonTypeInfo ResolveJsonTypeInfo(Type type, JsonSerializerOptions return options.TypeInfoResolver.GetTypeInfo(type, options); } - private static async Task> ToListAsync(this IAsyncEnumerable source) - { - var list = new List(); - await foreach (T item in source) - { - list.Add(item); - } - return list; - } - private sealed class SlowStream(IEnumerable byteSource) : Stream, IDisposable { private readonly IEnumerator _enumerator = byteSource.GetEnumerator(); diff --git a/src/libraries/System.Text.Json/tests/System.Text.Json.Tests/System.Text.Json.Tests.csproj b/src/libraries/System.Text.Json/tests/System.Text.Json.Tests/System.Text.Json.Tests.csproj index b70e73d3b83254..85cc1a9a57294c 100644 --- a/src/libraries/System.Text.Json/tests/System.Text.Json.Tests/System.Text.Json.Tests.csproj +++ b/src/libraries/System.Text.Json/tests/System.Text.Json.Tests/System.Text.Json.Tests.csproj @@ -269,6 +269,7 @@ + diff --git a/src/libraries/System.Threading.Tasks.Dataflow/tests/Dataflow/DataflowTestHelper.IAsyncEnumerable.cs b/src/libraries/System.Threading.Tasks.Dataflow/tests/Dataflow/DataflowTestHelper.IAsyncEnumerable.cs index 3eb8711a32e723..0b5907e3377f0d 100644 --- a/src/libraries/System.Threading.Tasks.Dataflow/tests/Dataflow/DataflowTestHelper.IAsyncEnumerable.cs +++ b/src/libraries/System.Threading.Tasks.Dataflow/tests/Dataflow/DataflowTestHelper.IAsyncEnumerable.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Linq; namespace System.Threading.Tasks.Dataflow.Tests { @@ -9,35 +10,4 @@ internal static partial class DataflowTestHelpers { internal static Func> ToAsyncEnumerable = item => AsyncEnumerable.Repeat(item, 1); } - - internal static partial class AsyncEnumerable - { - internal static async IAsyncEnumerable Repeat(int item, int count) - { - for (int i = 0; i < count; i++) - { - await Task.Yield(); - yield return item; - } - } - - internal static async IAsyncEnumerable Range(int start, int count) - { - var end = start + count; - for (int i = start; i < end; i++) - { - await Task.Yield(); - yield return i; - } - } - - internal static async IAsyncEnumerable ToAsyncEnumerable(this IEnumerable enumerable) - { - foreach (T item in enumerable) - { - await Task.Yield(); - yield return item; - } - } - } } diff --git a/src/libraries/System.Threading.Tasks.Dataflow/tests/System.Threading.Tasks.Dataflow.Tests.csproj b/src/libraries/System.Threading.Tasks.Dataflow/tests/System.Threading.Tasks.Dataflow.Tests.csproj index 813213d76d052e..1162f2ea07cea5 100644 --- a/src/libraries/System.Threading.Tasks.Dataflow/tests/System.Threading.Tasks.Dataflow.Tests.csproj +++ b/src/libraries/System.Threading.Tasks.Dataflow/tests/System.Threading.Tasks.Dataflow.Tests.csproj @@ -1,4 +1,4 @@ - + true $(NetCoreAppCurrent);$(NetFrameworkMinimum) @@ -33,5 +33,6 @@ +