Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 180 additions & 0 deletions src/DependencyInjection/DI/perf/CallSiteValidatorBenchmark.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Linq;
using System.Runtime.CompilerServices;
using BenchmarkDotNet.Attributes;
using Microsoft.Extensions.DependencyInjection.ServiceLookup;

namespace Microsoft.Extensions.DependencyInjection.Performance
{
public class CallSiteValidatorBenchmark
{
private ServiceCallSite _callSite;

[GlobalSetup]
public void Setup()
{
var services = new ServiceCollection();
services.AddTransient<A>();
services.AddTransient<B>();
services.AddTransient<C>();
services.AddTransient<D>();
services.AddTransient<E>();
services.AddTransient<F>();
services.AddTransient<G>();
services.AddTransient<H>();
services.AddTransient<I>();
services.AddTransient<J>();
services.AddTransient<K>();
services.AddTransient<L>();
services.AddTransient<M>();
services.AddTransient<N>();
services.AddTransient<O>();
services.AddTransient<P>();

var callSiteFactory = new CallSiteFactory(services.ToArray());

_callSite = callSiteFactory.GetCallSite(typeof(A), new CallSiteChain());
}

[Benchmark()]
public void ValidateCallSite()
{
var callSiteValidator = new CallSiteValidator();

callSiteValidator.ValidateCallSite(_callSite);
}

private class A
{
public A(B b, C c, D d, E e, F f, G g, H h, I i, J j, K k, L l)
{

}
}

private class B
{
public B(C c, D d, E e, F f, G g, H h, I i, J j, K k, L l)
{

}
}

private class C
{
public C(D d, E e, F f, G g, H h, I i, J j, K k, L l)
{

}

}

private class D
{
public D(E e, F f, G g, H h, I i, J j, K k, L l)
{

}
}

private class E
{
public E(F f, G g, H h, I i, J j, K k, L l)
{

}
}

private class F
{
public F(G g, H h, I i, J j, K k, L l)
{

}
}

private class G
{
public G(H h, I i, J j, K k, L l)
{

}
}

private class H
{
public H(I i, J j, K k, L l)
{

}
}

private class I
{
public I(J j, K k, L l)
{

}
}

private class J
{
public J(K k, L l)
{

}
}

private class K
{
public K(L l)
{

}
}

private class L
{
public L(M m)
{

}
}

private class M
{
public M(N n)
{

}
}

private class N
{
public N(O o)
{

}
}

private class O
{
public O(P p)
{

}
}

private class P
{
public P()
{

}
}

}
}
73 changes: 68 additions & 5 deletions src/DependencyInjection/DI/src/ServiceLookup/CallSiteValidator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;

namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
{
Expand All @@ -11,13 +12,69 @@ internal class CallSiteValidator: CallSiteVisitor<CallSiteValidator.CallSiteVali
// Keys are services being resolved via GetService, values - first scoped service in their call site tree
private readonly ConcurrentDictionary<Type, Type> _scopedServices = new ConcurrentDictionary<Type, Type>();

// Cache already-checked services that resulted in null.
private readonly HashSet<Type> _nonScopedServices = new HashSet<Type>();
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just be using ServiceCacheKey instead? Same for _scoped services?


public void ValidateCallSite(ServiceCallSite callSite)
{
var scoped = VisitCallSite(callSite, default);
VisitCallSite(callSite, default);
}

protected override Type VisitCallSite(ServiceCallSite callSite, CallSiteValidatorState argument)
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit - rename argument to state.

{
Type scoped;
bool ignoreServiceType = argument.IgnoreServiceType;

if ((!ignoreServiceType && _scopedServices.TryGetValue(callSite.ServiceType, out scoped)) ||
(callSite.ImplementationType != null && _scopedServices.TryGetValue(callSite.ImplementationType, out scoped)))
{
return scoped;
}
else
{
lock (_nonScopedServices)
{
if ((!ignoreServiceType && _nonScopedServices.Contains(callSite.ServiceType)) ||
(callSite.ImplementationType != null && _nonScopedServices.Contains(callSite.ImplementationType)))
{
return null;
}
}
}

argument.IgnoreServiceType = false;

scoped = base.VisitCallSite(callSite, argument);

if (scoped != null)
{
_scopedServices[callSite.ServiceType] = scoped;
if (!ignoreServiceType)
{
_scopedServices[callSite.ServiceType] = scoped;
}

if (callSite.ImplementationType != null)
{
_scopedServices[callSite.ImplementationType] = scoped;
}
}
else
{
lock (_nonScopedServices)
{
if (!ignoreServiceType)
{
_nonScopedServices.Add(callSite.ServiceType);
}

if (callSite.ImplementationType != null)
{
_nonScopedServices.Add(callSite.ImplementationType);
}
}
}

return scoped;
}

public void ValidateResolution(Type serviceType, IServiceScope scope, IServiceScope rootScope)
Expand Down Expand Up @@ -58,9 +115,14 @@ protected override Type VisitIEnumerable(IEnumerableCallSite enumerableCallSite,
CallSiteValidatorState state)
{
Type result = null;
foreach (var serviceCallSite in enumerableCallSite.ServiceCallSites)
ServiceCallSite[] serviceCallSites = enumerableCallSite.ServiceCallSites;

for (int i = 0; i < serviceCallSites.Length; i++)
{
var scoped = VisitCallSite(serviceCallSite, state);
// Ignore service type for all except the last element
state.IgnoreServiceType = i != serviceCallSites.Length - 1;

var scoped = VisitCallSite(serviceCallSites[i], state);
if (result == null)
{
result = scoped;
Expand Down Expand Up @@ -107,6 +169,7 @@ protected override Type VisitScopeCache(ServiceCallSite scopedCallSite, CallSite
internal struct CallSiteValidatorState
{
public ServiceCallSite Singleton { get; set; }
public bool IgnoreServiceType { get; set; }
}
}
}
}
60 changes: 60 additions & 0 deletions src/DependencyInjection/DI/test/ServiceProviderValidationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using Microsoft.Extensions.DependencyInjection.Specification.Fakes;
using Xunit;

Expand Down Expand Up @@ -110,6 +111,65 @@ public void GetService_DoesNotThrow_WhenScopeFactoryIsInjectedIntoSingleton()
Assert.NotNull(result);
}

[Fact]
public void GetService_DoesNotThrow_WhenGetServiceForServiceWithMultipleImplementationScopesWhereLastIsNotScoped()
{
// Arrange
var serviceCollection = new ServiceCollection();
serviceCollection.AddScoped<IBar, Bar>();
serviceCollection.AddSingleton<IBar, Bar2>();
serviceCollection.AddSingleton<IBaz, Baz>();
var serviceProvider = serviceCollection.BuildServiceProvider(true);


// Act + Assert
var exception = Assert.Throws<InvalidOperationException>(() => serviceProvider.GetService(typeof(IEnumerable<IBar>)));
Assert.Equal($"Cannot resolve scoped service '{typeof(IEnumerable<IBar>)}' from root provider.", exception.Message);

var result = serviceProvider.GetService(typeof(IBar));
Assert.NotNull(result);
}


[Fact]
public void GetService_Throws_WhenGetServiceForServiceWithMultipleImplementationScopesWhereLastIsScoped()
{
// Arrange
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IBar, Bar>();
serviceCollection.AddScoped<IBar, Bar2>();
serviceCollection.AddSingleton<IBaz, Baz>();
var serviceProvider = serviceCollection.BuildServiceProvider(true);


// Act + Assert
var exception = Assert.Throws<InvalidOperationException>(() => serviceProvider.GetService(typeof(IEnumerable<IBar>)));
Assert.Equal($"Cannot resolve scoped service '{typeof(IEnumerable<IBar>)}' from root provider.", exception.Message);

exception = Assert.Throws<InvalidOperationException>(() => serviceProvider.GetService(typeof(IBar)));
Assert.Equal($"Cannot resolve scoped service '{typeof(IBar)}' from root provider.", exception.Message);
}

[Fact]
public void GetService_DoesNotThrow_WhenGetServiceForNonScopedImplementationWithMultipleImplementationScopesWhereLastIsScoped()
{
// Arrange
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IBar, Bar>();
serviceCollection.AddSingleton<Bar>();
serviceCollection.AddScoped<IBar, Bar2>();
serviceCollection.AddSingleton<IBaz, Baz>();
var serviceProvider = serviceCollection.BuildServiceProvider(true);


// Act + Assert
var exception = Assert.Throws<InvalidOperationException>(() => serviceProvider.GetService(typeof(IEnumerable<IBar>)));
Assert.Equal($"Cannot resolve scoped service '{typeof(IEnumerable<IBar>)}' from root provider.", exception.Message);

var result = serviceProvider.GetService(typeof(Bar));
Assert.NotNull(result);
}

[Fact]
public void BuildServiceProvider_ValidateOnBuild_ThrowsForUnresolvableServices()
{
Expand Down