From 31276109d0bbd4ba17ba199755661ebdd97b0b5f Mon Sep 17 00:00:00 2001 From: Christiaan de Ridder Date: Thu, 21 Dec 2023 15:00:21 +0100 Subject: [PATCH 1/3] Cache all services encountered during scope validation to speed up validate on build --- .../src/ServiceLookup/CallSiteValidator.cs | 33 +++++++++++++------ 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs index e272c8a3d722b9..e377662675e05a 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs @@ -10,21 +10,15 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup internal sealed class CallSiteValidator : CallSiteVisitor { // Keys are services being resolved via GetService, values - first scoped service in their call site tree - private readonly ConcurrentDictionary _scopedServices = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _scopedServices = new ConcurrentDictionary(); - public void ValidateCallSite(ServiceCallSite callSite) - { - Type? scoped = VisitCallSite(callSite, default); - if (scoped != null) - { - _scopedServices[callSite.Cache.Key] = scoped; - } - } + public void ValidateCallSite(ServiceCallSite callSite) => VisitCallSite(callSite, default); public void ValidateResolution(ServiceCallSite callSite, IServiceScope scope, IServiceScope rootScope) { if (ReferenceEquals(scope, rootScope) - && _scopedServices.TryGetValue(callSite.Cache.Key, out Type? scopedService)) + && _scopedServices.TryGetValue(callSite.Cache.Key, out Type? scopedService) + && scopedService != null) { Type serviceType = callSite.ServiceType; if (serviceType == scopedService) @@ -42,6 +36,25 @@ public void ValidateResolution(ServiceCallSite callSite, IServiceScope scope, IS } } + protected override Type? VisitCallSite(ServiceCallSite callSite, CallSiteValidatorState argument) + { + // First, check if we have encountered this call site before to prevent visiting call site trees that have already been visited + // If firstScopedServiceInCallSiteTree is null there are no scoped dependencies in this service's call site tree + // If firstScopedServiceInCallSiteTree has a value, it contains the first scoped service in this service's call site truee + if (_scopedServices.TryGetValue(callSite.Cache.Key, out Type? firstScopedServiceInCallSiteTree)) + { + return firstScopedServiceInCallSiteTree; + } + + // Walk the tree + Type? scoped = base.VisitCallSite(callSite, argument); + + // Store the result for each visited service + _scopedServices[callSite.Cache.Key] = scoped; + + return scoped; + } + protected override Type? VisitConstructor(ConstructorCallSite constructorCallSite, CallSiteValidatorState state) { Type? result = null; From 0035946b685ef112f9ad6d52e475bcf76f5793bd Mon Sep 17 00:00:00 2001 From: Christiaan de Ridder Date: Thu, 21 Dec 2023 15:00:43 +0100 Subject: [PATCH 2/3] Fix typo --- .../src/ServiceLookup/CallSiteValidator.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs index e377662675e05a..433b53b5cebe6b 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs @@ -40,7 +40,7 @@ public void ValidateResolution(ServiceCallSite callSite, IServiceScope scope, IS { // First, check if we have encountered this call site before to prevent visiting call site trees that have already been visited // If firstScopedServiceInCallSiteTree is null there are no scoped dependencies in this service's call site tree - // If firstScopedServiceInCallSiteTree has a value, it contains the first scoped service in this service's call site truee + // If firstScopedServiceInCallSiteTree has a value, it contains the first scoped service in this service's call site tree if (_scopedServices.TryGetValue(callSite.Cache.Key, out Type? firstScopedServiceInCallSiteTree)) { return firstScopedServiceInCallSiteTree; From 772d865984b29a85e782d26e2f8ca8cb9d24943e Mon Sep 17 00:00:00 2001 From: Christiaan de Ridder Date: Fri, 2 Feb 2024 10:02:44 +0100 Subject: [PATCH 3/3] Add unit tests added in original PR --- .../ServiceProviderValidationTests.cs | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs index 8780312c2e8ff6..a5ee249bd4705e 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs @@ -180,6 +180,65 @@ public void GetService_DoesNotThrow_WhenScopeFactoryIsInjectedIntoSingleton() Assert.NotNull(result); } + [Fact] + public void GetService_DoesNotThrow_WhenGetServiceForServiceWithMultipleImplementationScopesWhereLastIsNotScoped() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddScoped(); + serviceCollection.AddSingleton(); + serviceCollection.AddSingleton(); + var serviceProvider = serviceCollection.BuildServiceProvider(true); + + + // Act + Assert + var exception = Assert.Throws(() => serviceProvider.GetService(typeof(IEnumerable))); + Assert.Equal($"Cannot resolve scoped service '{typeof(IEnumerable)}' from root provider.", exception.Message); + + var result = serviceProvider.GetService(typeof(IBar)); + Assert.NotNull(result); + } + + + [Fact] + public void GetService_Throws_WhenGetServiceForServiceWithMultipleImplementationScopesWhereLastIsScoped() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(); + serviceCollection.AddScoped(); + serviceCollection.AddSingleton(); + var serviceProvider = serviceCollection.BuildServiceProvider(true); + + + // Act + Assert + var exception = Assert.Throws(() => serviceProvider.GetService(typeof(IEnumerable))); + Assert.Equal($"Cannot resolve scoped service '{typeof(IEnumerable)}' from root provider.", exception.Message); + + exception = Assert.Throws(() => serviceProvider.GetService(typeof(IBar))); + Assert.Equal($"Cannot resolve scoped service '{typeof(IBar)}' from root provider.", exception.Message); + } + + [Fact] + public void GetService_DoesNotThrow_WhenGetServiceForNonScopedImplementationWithMultipleImplementationScopesWhereLastIsScoped() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(); + serviceCollection.AddSingleton(); + serviceCollection.AddScoped(); + serviceCollection.AddSingleton(); + var serviceProvider = serviceCollection.BuildServiceProvider(true); + + + // Act + Assert + var exception = Assert.Throws(() => serviceProvider.GetService(typeof(IEnumerable))); + Assert.Equal($"Cannot resolve scoped service '{typeof(IEnumerable)}' from root provider.", exception.Message); + + var result = serviceProvider.GetService(typeof(Bar)); + Assert.NotNull(result); + } + [Fact] public void BuildServiceProvider_ValidateOnBuild_ThrowsForUnresolvableServices() {