Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 220 additions & 2 deletions src/GraphQL.EntityFramework/GraphApi/EfGraphQLService_Navigation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ public FieldBuilder<TSource, TReturn> AddNavigationField<TSource, TReturn, TProj

var compiledProjection = projection.Compile();

// Get filter-required navigation paths at setup time for reloading if needed
var filterRequiredNavPaths = GetFilterRequiredNavPathsForReload<TReturn>();

field.Resolver = new FuncFieldResolver<TSource, TReturn?>(
async context =>
{
Expand Down Expand Up @@ -119,8 +122,21 @@ public FieldBuilder<TSource, TReturn> AddNavigationField<TSource, TReturn, TProj
exception);
}

if (fieldContext.Filters == null ||
await fieldContext.Filters.ShouldInclude(context.UserContext, fieldContext.DbContext, context.User, result))
if (fieldContext.Filters == null)
{
return result;
}

// If filter requires navigation properties, reload the entity with those includes
if (result != null && filterRequiredNavPaths.Count > 0)
{
result = await ReloadWithFilterNavigations(
fieldContext.DbContext,
result,
filterRequiredNavPaths);
}

if (await fieldContext.Filters.ShouldInclude(context.UserContext, fieldContext.DbContext, context.User, result))
{
return result;
}
Expand Down Expand Up @@ -158,4 +174,206 @@ public FieldBuilder<TSource, TReturn> AddNavigationField<TSource, TReturn>(
graph.AddField(field);
return new FieldBuilderEx<TSource, TReturn>(field);
}

/// <summary>
/// Gets the navigation paths required by filters for reloading entities.
/// Returns just the navigation parts (not prefixed with field name).
/// </summary>
IReadOnlyList<string> GetFilterRequiredNavPathsForReload<TReturn>()
where TReturn : class
{
var filters = resolveFilters?.Invoke(null!);
if (filters == null)
{
return [];
}

var requiredProps = filters.GetRequiredFilterProperties<TReturn>();
var navigationPaths = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

foreach (var prop in requiredProps)
{
var lastDot = prop.LastIndexOf('.');
if (lastDot > 0)
{
// e.g., "TravelRequest.GroupOwnerId" -> "TravelRequest"
var navPath = prop[..lastDot];
navigationPaths.Add(navPath);
}
}

return [.. navigationPaths];
}

/// <summary>
/// Reloads an entity from the database with the specified navigation properties included.
/// </summary>
static async Task<TReturn?> ReloadWithFilterNavigations<TReturn>(
TDbContext dbContext,
TReturn entity,
IReadOnlyList<string> navigationPaths)
where TReturn : class
{
if (navigationPaths.Count == 0)
{
return entity;
}

// Get the entity's primary key
var entityType = dbContext.Model.FindEntityType(typeof(TReturn));
if (entityType == null)
{
return entity;
}

var primaryKey = entityType.FindPrimaryKey();
if (primaryKey == null)
{
return entity;
}

// Get the key values from the entity
var keyValues = primaryKey.Properties
.Select(p => p.PropertyInfo?.GetValue(entity) ?? p.FieldInfo?.GetValue(entity))
.ToArray();

if (keyValues.Any(v => v == null))
{
return entity;
}

// Build a query with includes
IQueryable<TReturn> query = dbContext.Set<TReturn>();

foreach (var navPath in navigationPaths)
{
query = query.Include(navPath);
}

// Filter by primary key
var keyProperties = primaryKey.Properties.ToList();
if (keyProperties.Count == 1)
{
// Single key - use simple Find-like behavior with includes
var keyProperty = keyProperties[0];
var parameter = Expression.Parameter(typeof(TReturn), "e");
var propertyAccess = Expression.Property(parameter, keyProperty.PropertyInfo!);
var constant = Expression.Constant(keyValues[0]);
var equals = Expression.Equal(propertyAccess, constant);
var lambda = Expression.Lambda<Func<TReturn, bool>>(equals, parameter);

return await query.FirstOrDefaultAsync(lambda);
}

// Composite key - need to build combined predicate
var param = Expression.Parameter(typeof(TReturn), "e");
Expression? predicate = null;

for (var i = 0; i < keyProperties.Count; i++)
{
var keyProperty = keyProperties[i];
var propertyAccess = Expression.Property(param, keyProperty.PropertyInfo!);
var constant = Expression.Constant(keyValues[i]);
var equals = Expression.Equal(propertyAccess, constant);

predicate = predicate == null ? equals : Expression.AndAlso(predicate, equals);
}

var lambdaExpr = Expression.Lambda<Func<TReturn, bool>>(predicate!, param);
return await query.FirstOrDefaultAsync(lambdaExpr);
}

/// <summary>
/// Batch reloads multiple entities from the database with the specified navigation properties included.
/// Uses a single query with WHERE Id IN (...) instead of N+1 queries.
/// </summary>
static async Task<IReadOnlyList<TReturn>> BatchReloadWithFilterNavigations<TReturn>(
TDbContext dbContext,
IEnumerable<TReturn> entities,
IReadOnlyList<string> navigationPaths)
where TReturn : class
{
var entityList = entities.ToList();
if (entityList.Count == 0 || navigationPaths.Count == 0)
{
return entityList;
}

// Get the entity's primary key metadata
var entityType = dbContext.Model.FindEntityType(typeof(TReturn));
if (entityType == null)
{
return entityList;
}

var primaryKey = entityType.FindPrimaryKey();
if (primaryKey == null)
{
return entityList;
}

var keyProperties = primaryKey.Properties.ToList();
if (keyProperties.Count != 1)
{
// For composite keys, fall back to individual reloads
var results = new List<TReturn>();
foreach (var entity in entityList)
{
var reloaded = await ReloadWithFilterNavigations(dbContext, entity, navigationPaths);
if (reloaded != null)
{
results.Add(reloaded);
}
}
return results;
}

// Single key - can use IN clause
var keyProperty = keyProperties[0];
var keyValues = entityList
.Select(e => keyProperty.PropertyInfo?.GetValue(e) ?? keyProperty.FieldInfo?.GetValue(e))
.Where(v => v != null)
.ToList();

if (keyValues.Count == 0)
{
return entityList;
}

// Build a query with includes
IQueryable<TReturn> query = dbContext.Set<TReturn>();

foreach (var navPath in navigationPaths)
{
query = query.Include(navPath);
}

// Build WHERE Id IN (...) predicate
var parameter = Expression.Parameter(typeof(TReturn), "e");
var propertyAccess = Expression.Property(parameter, keyProperty.PropertyInfo!);

// Create a list of the key type and use Contains
var keyType = keyProperty.ClrType;
var typedKeyValues = typeof(Enumerable)
.GetMethod("Cast")!
.MakeGenericMethod(keyType)
.Invoke(null, [keyValues])!;
var keyList = typeof(Enumerable)
.GetMethod("ToList")!
.MakeGenericMethod(keyType)
.Invoke(null, [typedKeyValues])!;

var containsMethod = typeof(List<>)
.MakeGenericType(keyType)
.GetMethod("Contains", [keyType])!;

var containsCall = Expression.Call(
Expression.Constant(keyList),
containsMethod,
propertyAccess);

var lambda = Expression.Lambda<Func<TReturn, bool>>(containsCall, parameter);

return await query.Where(lambda).ToListAsync();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ public FieldBuilder<TSource, TReturn> AddNavigationListField<TSource, TReturn, T

var compiledProjection = projection.Compile();

// Get filter-required navigation paths at setup time for reloading if needed
var filterRequiredNavPaths = GetFilterRequiredNavPathsForReload<TReturn>();

field.Resolver = new FuncFieldResolver<TSource, IEnumerable<TReturn>>(async context =>
{
var fieldContext = BuildContext(context);
Expand Down Expand Up @@ -104,6 +107,15 @@ public FieldBuilder<TSource, TReturn> AddNavigationListField<TSource, TReturn, T
return result;
}

// If filter requires navigation properties, batch reload items with those includes
if (filterRequiredNavPaths.Count > 0)
{
result = await BatchReloadWithFilterNavigations(
fieldContext.DbContext,
result,
filterRequiredNavPaths);
}

return await fieldContext.Filters.ApplyFilter(result, context.UserContext, fieldContext.DbContext, context.User);
});

Expand Down
88 changes: 30 additions & 58 deletions src/GraphQL.EntityFramework/IncludeAppender.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,74 +57,46 @@ internal IQueryable<TItem> AddIncludesWithFiltersAndDetectNavigations<TItem>(
/// <summary>
/// Checks if the query is the result of a projection (Select).
/// When a query has been projected via Select, Include cannot be applied.
/// We detect this by examining if the expression tree contains a Select call.
/// We detect this by walking the LINQ method call chain (not lambda bodies).
/// </summary>
static bool IsProjectedQuery<TItem>(IQueryable<TItem> query)
where TItem : class =>
ContainsSelectMethod(query.Expression);
HasSelectInQueryChain(query.Expression);

static bool ContainsSelectMethod(Expression expression)
/// <summary>
/// Walks the LINQ method call chain to find Select calls.
/// Only checks the source argument of LINQ operators, not lambda bodies.
/// </summary>
static bool HasSelectInQueryChain(Expression expression)
{
while (true)
while (expression is MethodCallExpression methodCall)
{
switch (expression)
// Check if this is a LINQ Select call
if (methodCall.Method.Name == "Select" &&
methodCall.Method.DeclaringType is { } declaringType &&
(declaringType == typeof(Queryable) || declaringType == typeof(Enumerable)))
{
case MethodCallExpression methodCall:
// Check if this is a Select call
if (methodCall.Method.Name == "Select")
{
return true;
}

// Check all arguments recursively
foreach (var arg in methodCall.Arguments)
{
if (ContainsSelectMethod(arg))
{
return true;
}
}

// Check the object (for instance method calls)
return methodCall.Object != null && ContainsSelectMethod(methodCall.Object);

case UnaryExpression unary:
expression = unary.Operand;
continue;

case LambdaExpression lambda:
expression = lambda.Body;
continue;

case MemberExpression member:
return member.Expression != null && ContainsSelectMethod(member.Expression);

case BinaryExpression binary:
return ContainsSelectMethod(binary.Left) || ContainsSelectMethod(binary.Right);

case ConditionalExpression conditional:
return ContainsSelectMethod(conditional.Test) || ContainsSelectMethod(conditional.IfTrue) || ContainsSelectMethod(conditional.IfFalse);

case InvocationExpression invocation:
if (ContainsSelectMethod(invocation.Expression))
{
return true;
}

foreach (var arg in invocation.Arguments)
{
if (ContainsSelectMethod(arg))
{
return true;
}
}

return false;
return true;
}

default:
return false;
// For LINQ extension methods, the source is the first argument
// Move to the source queryable to continue walking the chain
if (methodCall.Arguments.Count > 0)
{
expression = methodCall.Arguments[0];
}
else if (methodCall.Object != null)
{
// For instance methods, check the object
expression = methodCall.Object;
}
else
{
break;
}
}

return false;
}

IQueryable<TItem> AddFilterNavigationIncludes<TItem>(
Expand Down
Loading