From 05618f521070ccc7ecc1aa4aaeaa9000b72e7109 Mon Sep 17 00:00:00 2001 From: Brandon Dahler Date: Sat, 21 Sep 2019 01:57:25 -0400 Subject: [PATCH 1/3] Cache CallSiteValidator results, turning O(n!) become O(n) --- .../DI/perf/CallSiteValidatorBenchmark.cs | 186 ++++++++++++++++++ .../DI/src/ServiceLookup/CallSiteValidator.cs | 68 ++++++- .../DI/test/ServiceProviderValidationTests.cs | 61 ++++++ 3 files changed, 310 insertions(+), 5 deletions(-) create mode 100644 src/DependencyInjection/DI/perf/CallSiteValidatorBenchmark.cs diff --git a/src/DependencyInjection/DI/perf/CallSiteValidatorBenchmark.cs b/src/DependencyInjection/DI/perf/CallSiteValidatorBenchmark.cs new file mode 100644 index 00000000000..be9852b13d7 --- /dev/null +++ b/src/DependencyInjection/DI/perf/CallSiteValidatorBenchmark.cs @@ -0,0 +1,186 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; +using System.Runtime.CompilerServices; +using BenchmarkDotNet.Attributes; +using Microsoft.Extensions.DependencyInjection.ServiceLookup; + +namespace Microsoft.Extensions.DependencyInjection.Performance +{ + public class CallSiteValidatorBenchmark + { + private ServiceCallSite _callSite; + + [GlobalSetup] + public void Setup() + { + var services = new ServiceCollection(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient(); + services.AddTransient

(); + + var callSiteFactory = new CallSiteFactory(services.ToArray()); + + _callSite = callSiteFactory.GetCallSite(typeof(A), new CallSiteChain()); + } + + [Benchmark()] + public void ValidateCallSite() + { + var callSiteValidator = new CallSiteValidator(); + + callSiteValidator.ValidateCallSite(_callSite); + } + + private class A + { + public A(B b, C c, D d, E e, F f, G g, H h, I i, J j, K k, L l) + { + + } + + [MethodImpl(MethodImplOptions.NoInlining)] + public void Foo() + { + + } + } + + private class B + { + public B(C c, D d, E e, F f, G g, H h, I i, J j, K k, L l) + { + + } + } + + private class C + { + public C(D d, E e, F f, G g, H h, I i, J j, K k, L l) + { + + } + + } + + private class D + { + public D(E e, F f, G g, H h, I i, J j, K k, L l) + { + + } + } + + private class E + { + public E(F f, G g, H h, I i, J j, K k, L l) + { + + } + } + + private class F + { + public F(G g, H h, I i, J j, K k, L l) + { + + } + } + + private class G + { + public G(H h, I i, J j, K k, L l) + { + + } + } + + private class H + { + public H(I i, J j, K k, L l) + { + + } + } + + private class I + { + public I(J j, K k, L l) + { + + } + } + + private class J + { + public J(K k, L l) + { + + } + } + + private class K + { + public K(L l) + { + + } + } + + private class L + { + public L(M m) + { + + } + } + + private class M + { + public M(N n) + { + + } + } + + private class N + { + public N(O o) + { + + } + } + + private class O + { + public O(P p) + { + + } + } + + private class P + { + public P() + { + + } + } + + } +} diff --git a/src/DependencyInjection/DI/src/ServiceLookup/CallSiteValidator.cs b/src/DependencyInjection/DI/src/ServiceLookup/CallSiteValidator.cs index 42ba65fbbf4..5486dd746c9 100644 --- a/src/DependencyInjection/DI/src/ServiceLookup/CallSiteValidator.cs +++ b/src/DependencyInjection/DI/src/ServiceLookup/CallSiteValidator.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Concurrent; +using System.Collections.Generic; namespace Microsoft.Extensions.DependencyInjection.ServiceLookup { @@ -11,15 +12,66 @@ internal class CallSiteValidator: CallSiteVisitor _scopedServices = new ConcurrentDictionary(); + // Cache already-checked services that resulted in null. + private readonly HashSet _nonScopedServices = new HashSet(); + public void ValidateCallSite(ServiceCallSite callSite) { - var scoped = VisitCallSite(callSite, default); + VisitCallSite(callSite, default); + } + + protected override Type VisitCallSite(ServiceCallSite callSite, CallSiteValidatorState argument) + { + Type scoped; + bool ignoreServiceType = argument.IgnoreServiceType; + + if ((!ignoreServiceType && _scopedServices.TryGetValue(callSite.ServiceType, out scoped)) || + _scopedServices.TryGetValue(callSite.ImplementationType, out scoped)) + { + return scoped; + } + else + { + lock (_nonScopedServices) + { + if ((!ignoreServiceType && _nonScopedServices.Contains(callSite.ServiceType)) || + _nonScopedServices.Contains(callSite.ImplementationType)) + { + return null; + } + } + } + + argument.IgnoreServiceType = false; + + scoped = base.VisitCallSite(callSite, argument); + if (scoped != null) { - _scopedServices[callSite.ServiceType] = scoped; + _scopedServices[callSite.ImplementationType] = scoped; + + if (!ignoreServiceType) + { + _scopedServices[callSite.ServiceType] = scoped; + } } + else + { + lock (_nonScopedServices) + { + _nonScopedServices.Add(callSite.ImplementationType); + + if (!ignoreServiceType) + { + _nonScopedServices.Add(callSite.ServiceType); + } + } + } + + return scoped; } + public void ValidateResolution(Type serviceType, IServiceScope scope, IServiceScope rootScope) { if (ReferenceEquals(scope, rootScope) @@ -58,9 +110,14 @@ protected override Type VisitIEnumerable(IEnumerableCallSite enumerableCallSite, CallSiteValidatorState state) { Type result = null; - foreach (var serviceCallSite in enumerableCallSite.ServiceCallSites) + ServiceCallSite[] serviceCallSites = enumerableCallSite.ServiceCallSites; + + for (int i = 0; i < serviceCallSites.Length; i++) { - var scoped = VisitCallSite(serviceCallSite, state); + // Ignore service type for all except the last element + state.IgnoreServiceType = i != serviceCallSites.Length - 1; + + var scoped = VisitCallSite(serviceCallSites[i], state); if (result == null) { result = scoped; @@ -107,6 +164,7 @@ protected override Type VisitScopeCache(ServiceCallSite scopedCallSite, CallSite internal struct CallSiteValidatorState { public ServiceCallSite Singleton { get; set; } + public bool IgnoreServiceType { get; set; } } } -} \ No newline at end of file +} diff --git a/src/DependencyInjection/DI/test/ServiceProviderValidationTests.cs b/src/DependencyInjection/DI/test/ServiceProviderValidationTests.cs index 613370240cd..8605486b721 100644 --- a/src/DependencyInjection/DI/test/ServiceProviderValidationTests.cs +++ b/src/DependencyInjection/DI/test/ServiceProviderValidationTests.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using Microsoft.Extensions.DependencyInjection.Specification.Fakes; using Xunit; @@ -110,6 +111,66 @@ public void GetService_DoesNotThrow_WhenScopeFactoryIsInjectedIntoSingleton() Assert.NotNull(result); } + + [Fact] + public void GetService_DoesNotThrow_WhenGetServiceForServiceWithMultipleImplementationScopesWhereLastIsNotScoped() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddScoped(); + serviceCollection.AddSingleton(); + serviceCollection.AddSingleton(); + var serviceProvider = serviceCollection.BuildServiceProvider(true); + + + // Act + Assert + var exception = Assert.Throws(() => serviceProvider.GetService(typeof(IEnumerable))); + Assert.Equal($"Cannot resolve scoped service '{typeof(IEnumerable)}' from root provider.", exception.Message); + + var result = serviceProvider.GetService(typeof(IBar)); + Assert.NotNull(result); + } + + + [Fact] + public void GetService_Throws_WhenGetServiceForServiceWithMultipleImplementationScopesWhereLastIsScoped() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(); + serviceCollection.AddScoped(); + serviceCollection.AddSingleton(); + var serviceProvider = serviceCollection.BuildServiceProvider(true); + + + // Act + Assert + var exception = Assert.Throws(() => serviceProvider.GetService(typeof(IEnumerable))); + Assert.Equal($"Cannot resolve scoped service '{typeof(IEnumerable)}' from root provider.", exception.Message); + + exception = Assert.Throws(() => serviceProvider.GetService(typeof(IBar))); + Assert.Equal($"Cannot resolve scoped service '{typeof(IBar)}' from root provider.", exception.Message); + } + + [Fact] + public void GetService_DoesNotThrow_WhenGetServiceForNonScopedImplementationWithMultipleImplementationScopesWhereLastIsScoped() + { + // Arrange + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(); + serviceCollection.AddSingleton(); + serviceCollection.AddScoped(); + serviceCollection.AddSingleton(); + var serviceProvider = serviceCollection.BuildServiceProvider(true); + + + // Act + Assert + var exception = Assert.Throws(() => serviceProvider.GetService(typeof(IEnumerable))); + Assert.Equal($"Cannot resolve scoped service '{typeof(IEnumerable)}' from root provider.", exception.Message); + + var result = serviceProvider.GetService(typeof(Bar)); + Assert.NotNull(result); + } + [Fact] public void BuildServiceProvider_ValidateOnBuild_ThrowsForUnresolvableServices() { From e77fc2ae9e01c767322e8453974bd401c655a763 Mon Sep 17 00:00:00 2001 From: Brandon Dahler Date: Sat, 21 Sep 2019 03:15:43 -0400 Subject: [PATCH 2/3] Nits --- .../DI/perf/CallSiteValidatorBenchmark.cs | 6 ------ .../DI/src/ServiceLookup/CallSiteValidator.cs | 1 - .../DI/test/ServiceProviderValidationTests.cs | 1 - 3 files changed, 8 deletions(-) diff --git a/src/DependencyInjection/DI/perf/CallSiteValidatorBenchmark.cs b/src/DependencyInjection/DI/perf/CallSiteValidatorBenchmark.cs index be9852b13d7..23e79750c81 100644 --- a/src/DependencyInjection/DI/perf/CallSiteValidatorBenchmark.cs +++ b/src/DependencyInjection/DI/perf/CallSiteValidatorBenchmark.cs @@ -53,12 +53,6 @@ public A(B b, C c, D d, E e, F f, G g, H h, I i, J j, K k, L l) { } - - [MethodImpl(MethodImplOptions.NoInlining)] - public void Foo() - { - - } } private class B diff --git a/src/DependencyInjection/DI/src/ServiceLookup/CallSiteValidator.cs b/src/DependencyInjection/DI/src/ServiceLookup/CallSiteValidator.cs index 5486dd746c9..6dbc2dc6052 100644 --- a/src/DependencyInjection/DI/src/ServiceLookup/CallSiteValidator.cs +++ b/src/DependencyInjection/DI/src/ServiceLookup/CallSiteValidator.cs @@ -71,7 +71,6 @@ protected override Type VisitCallSite(ServiceCallSite callSite, CallSiteValidato return scoped; } - public void ValidateResolution(Type serviceType, IServiceScope scope, IServiceScope rootScope) { if (ReferenceEquals(scope, rootScope) diff --git a/src/DependencyInjection/DI/test/ServiceProviderValidationTests.cs b/src/DependencyInjection/DI/test/ServiceProviderValidationTests.cs index 8605486b721..2ee4e13f82c 100644 --- a/src/DependencyInjection/DI/test/ServiceProviderValidationTests.cs +++ b/src/DependencyInjection/DI/test/ServiceProviderValidationTests.cs @@ -111,7 +111,6 @@ public void GetService_DoesNotThrow_WhenScopeFactoryIsInjectedIntoSingleton() Assert.NotNull(result); } - [Fact] public void GetService_DoesNotThrow_WhenGetServiceForServiceWithMultipleImplementationScopesWhereLastIsNotScoped() { From 6fcb9c3616b17bb740a7dfb43971146de97293cd Mon Sep 17 00:00:00 2001 From: Brandon Dahler Date: Sat, 21 Sep 2019 03:45:01 -0400 Subject: [PATCH 3/3] Handle FactoryCallSite, where ImplementationType == null. --- .../DI/src/ServiceLookup/CallSiteValidator.cs | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/DependencyInjection/DI/src/ServiceLookup/CallSiteValidator.cs b/src/DependencyInjection/DI/src/ServiceLookup/CallSiteValidator.cs index 6dbc2dc6052..a99688e9fd6 100644 --- a/src/DependencyInjection/DI/src/ServiceLookup/CallSiteValidator.cs +++ b/src/DependencyInjection/DI/src/ServiceLookup/CallSiteValidator.cs @@ -26,7 +26,7 @@ protected override Type VisitCallSite(ServiceCallSite callSite, CallSiteValidato bool ignoreServiceType = argument.IgnoreServiceType; if ((!ignoreServiceType && _scopedServices.TryGetValue(callSite.ServiceType, out scoped)) || - _scopedServices.TryGetValue(callSite.ImplementationType, out scoped)) + (callSite.ImplementationType != null && _scopedServices.TryGetValue(callSite.ImplementationType, out scoped))) { return scoped; } @@ -35,7 +35,7 @@ protected override Type VisitCallSite(ServiceCallSite callSite, CallSiteValidato lock (_nonScopedServices) { if ((!ignoreServiceType && _nonScopedServices.Contains(callSite.ServiceType)) || - _nonScopedServices.Contains(callSite.ImplementationType)) + (callSite.ImplementationType != null && _nonScopedServices.Contains(callSite.ImplementationType))) { return null; } @@ -48,23 +48,29 @@ protected override Type VisitCallSite(ServiceCallSite callSite, CallSiteValidato if (scoped != null) { - _scopedServices[callSite.ImplementationType] = scoped; - if (!ignoreServiceType) { _scopedServices[callSite.ServiceType] = scoped; } + + if (callSite.ImplementationType != null) + { + _scopedServices[callSite.ImplementationType] = scoped; + } } else { lock (_nonScopedServices) { - _nonScopedServices.Add(callSite.ImplementationType); - if (!ignoreServiceType) { _nonScopedServices.Add(callSite.ServiceType); } + + if (callSite.ImplementationType != null) + { + _nonScopedServices.Add(callSite.ImplementationType); + } } }