diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs index e272c8a3d722b9..433b53b5cebe6b 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs @@ -10,21 +10,15 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup internal sealed class CallSiteValidator : CallSiteVisitor { // Keys are services being resolved via GetService, values - first scoped service in their call site tree - private readonly ConcurrentDictionary _scopedServices = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _scopedServices = new ConcurrentDictionary(); - public void ValidateCallSite(ServiceCallSite callSite) - { - Type? scoped = VisitCallSite(callSite, default); - if (scoped != null) - { - _scopedServices[callSite.Cache.Key] = scoped; - } - } + public void ValidateCallSite(ServiceCallSite callSite) => VisitCallSite(callSite, default); public void ValidateResolution(ServiceCallSite callSite, IServiceScope scope, IServiceScope rootScope) { if (ReferenceEquals(scope, rootScope) - && _scopedServices.TryGetValue(callSite.Cache.Key, out Type? scopedService)) + && _scopedServices.TryGetValue(callSite.Cache.Key, out Type? scopedService) + && scopedService != null) { Type serviceType = callSite.ServiceType; if (serviceType == scopedService) @@ -42,6 +36,25 @@ public void ValidateResolution(ServiceCallSite callSite, IServiceScope scope, IS } } + protected override Type? VisitCallSite(ServiceCallSite callSite, CallSiteValidatorState argument) + { + // First, check if we have encountered this call site before to prevent visiting call site trees that have already been visited + // If firstScopedServiceInCallSiteTree is null there are no scoped dependencies in this service's call site tree + // If firstScopedServiceInCallSiteTree has a value, it contains the first scoped service in this service's call site tree + if (_scopedServices.TryGetValue(callSite.Cache.Key, out Type? firstScopedServiceInCallSiteTree)) + { + return firstScopedServiceInCallSiteTree; + } + + // Walk the tree + Type? scoped = base.VisitCallSite(callSite, argument); + + // Store the result for each visited service + _scopedServices[callSite.Cache.Key] = scoped; + + return scoped; + } + protected override Type? VisitConstructor(ConstructorCallSite constructorCallSite, CallSiteValidatorState state) { Type? result = null; diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs index 8780312c2e8ff6..a5ee249bd4705e 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs @@ -180,6 +180,65 @@ public void GetService_DoesNotThrow_WhenScopeFactoryIsInjectedIntoSingleton() Assert.NotNull(result); } + [Fact] + public void GetService_DoesNotThrow_WhenGetServiceForServiceWithMultipleImplementationScopesWhereLastIsNotScoped() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddScoped(); + serviceCollection.AddSingleton(); + serviceCollection.AddSingleton(); + var serviceProvider = serviceCollection.BuildServiceProvider(true); + + + // Act + Assert + var exception = Assert.Throws(() => serviceProvider.GetService(typeof(IEnumerable))); + Assert.Equal($"Cannot resolve scoped service '{typeof(IEnumerable)}' from root provider.", exception.Message); + + var result = serviceProvider.GetService(typeof(IBar)); + Assert.NotNull(result); + } + + + [Fact] + public void GetService_Throws_WhenGetServiceForServiceWithMultipleImplementationScopesWhereLastIsScoped() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(); + serviceCollection.AddScoped(); + serviceCollection.AddSingleton(); + var serviceProvider = serviceCollection.BuildServiceProvider(true); + + + // Act + Assert + var exception = Assert.Throws(() => serviceProvider.GetService(typeof(IEnumerable))); + Assert.Equal($"Cannot resolve scoped service '{typeof(IEnumerable)}' from root provider.", exception.Message); + + exception = Assert.Throws(() => serviceProvider.GetService(typeof(IBar))); + Assert.Equal($"Cannot resolve scoped service '{typeof(IBar)}' from root provider.", exception.Message); + } + + [Fact] + public void GetService_DoesNotThrow_WhenGetServiceForNonScopedImplementationWithMultipleImplementationScopesWhereLastIsScoped() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(); + serviceCollection.AddSingleton(); + serviceCollection.AddScoped(); + serviceCollection.AddSingleton(); + var serviceProvider = serviceCollection.BuildServiceProvider(true); + + + // Act + Assert + var exception = Assert.Throws(() => serviceProvider.GetService(typeof(IEnumerable))); + Assert.Equal($"Cannot resolve scoped service '{typeof(IEnumerable)}' from root provider.", exception.Message); + + var result = serviceProvider.GetService(typeof(Bar)); + Assert.NotNull(result); + } + [Fact] public void BuildServiceProvider_ValidateOnBuild_ThrowsForUnresolvableServices() {