Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 187 additions & 4 deletions src/GraphQL.EntityFramework/Filters/FilterEntry.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ class FilterEntry<TDbContext, TEntity, TProjection> : IFilterEntry<TDbContext>
{
Func<object, TDbContext, ClaimsPrincipal?, TProjection, Task<bool>> filter;
Func<object, TProjection>? compiledProjection;
IReadOnlySet<string> requiredPropertyNames;

public FilterEntry(
Func<object, TDbContext, ClaimsPrincipal?, TProjection, Task<bool>> filter,
Expand All @@ -13,18 +14,200 @@ public FilterEntry(
if (projection is null)
{
compiledProjection = null;
RequiredPropertyNames = new HashSet<string>();
requiredPropertyNames = new HashSet<string>();
}
else
{
var compiled = projection.Compile();
compiledProjection = entity => compiled((TEntity)entity);
RequiredPropertyNames = ProjectionAnalyzer.ExtractRequiredProperties(projection);
ValidateProjectionCompatibility(projection, RequiredPropertyNames);
requiredPropertyNames = ProjectionAnalyzer.ExtractRequiredProperties(projection);
ValidateProjectionCompatibility(projection, requiredPropertyNames);
}
}

public IReadOnlySet<string> RequiredPropertyNames { get; }
public IReadOnlySet<string> RequiredPropertyNames => requiredPropertyNames;

public FieldProjectionInfo AddRequirements(
FieldProjectionInfo projection,
IReadOnlyDictionary<string, Navigation>? navigationProperties)
{
if (requiredPropertyNames.Count == 0)
{
return projection;
}

// Separate simple fields and navigation paths
var scalarFieldsToAdd = new List<string>();
var navigationPaths = new Dictionary<string, HashSet<string>>(StringComparer.OrdinalIgnoreCase);

foreach (var field in requiredPropertyNames)
{
if (field.Contains('.'))
{
// Navigation path like "Parent.Id"
var parts = field.Split('.', 2);
var navName = parts[0];
var navProperty = parts[1];

if (!navigationPaths.TryGetValue(navName, out var properties))
{
properties = new(StringComparer.OrdinalIgnoreCase);
navigationPaths[navName] = properties;
}

// Only add if it doesn't contain further dots (single-level navigation)
if (!navProperty.Contains('.'))
{
properties.Add(navProperty);
}
}
else
{
// Simple field - check if it's a navigation property
var isNavigation = navigationProperties?.ContainsKey(field) == true;

if (isNavigation ||
projection.ScalarFields.Contains(field) ||
projection.KeyNames?.Contains(field, StringComparer.OrdinalIgnoreCase) == true)
{
// Skip navigation names - they'll be handled via navigation paths
continue;
}

scalarFieldsToAdd.Add(field);
}
}

// Merge scalar fields
var mergedScalars = new HashSet<string>(projection.ScalarFields, StringComparer.OrdinalIgnoreCase);
foreach (var field in scalarFieldsToAdd)
{
mergedScalars.Add(field);
}

// Merge navigations
var infos = projection.Navigations;
Dictionary<string, NavigationProjectionInfo> mergedNavigations;
if (infos == null)
{
mergedNavigations = [];
}
else
{
mergedNavigations = new(infos);
}

// Process navigation paths from filter fields
foreach (var (navName, requiredProps) in navigationPaths)
{
// Skip if no navigation metadata available for this entity type
if (navigationProperties == null)
{
continue;
}

// Try to find the navigation - use case-insensitive search
Navigation? navMetadata = null;
foreach (var (key, value) in navigationProperties)
{
if (string.Equals(key, navName, StringComparison.OrdinalIgnoreCase))
{
navMetadata = value;
break;
}
}

if (navMetadata == null)
{
continue;
}

var navType = navMetadata.Type;
if (mergedNavigations.TryGetValue(navName, out var existingNav))
{
// Navigation exists in GraphQL query - add filter-required properties to its projection
var updatedScalars = new HashSet<string>(existingNav.Projection.ScalarFields, StringComparer.OrdinalIgnoreCase);
foreach (var prop in requiredProps)
{
updatedScalars.Add(prop);
}

var updatedProjection = existingNav.Projection with
{
ScalarFields = updatedScalars
};
mergedNavigations[navName] = existingNav with
{
Projection = updatedProjection
};
}
else
{
// Create navigation projection for filter-only navigations
// Note: For abstract types, SelectExpressionBuilder.TryBuild will return false,
// causing the entire projection to fail. This is intentional - it ensures
// Include (added in AddFilterNavigationIncludes) is used instead of Select.
// Don't include key/FK columns for filter-only navigations - the filter only
// needs the specific properties it accesses.
var navProjection = new FieldProjectionInfo(requiredProps, null, null, null);
mergedNavigations[navName] = new(navType, navMetadata.IsCollection, navProjection);
}
}

return projection with
{
ScalarFields = mergedScalars,
Navigations = mergedNavigations
};
}

public IEnumerable<string> GetAbstractNavigationIncludes(
IReadOnlyDictionary<string, Navigation>? navigationProperties)
{
if (navigationProperties == null)
{
yield break;
}

// Extract navigation names from filter fields (paths like "Parent.Property" -> "Parent")
var navigationNames = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
foreach (var field in requiredPropertyNames)
{
if (field.Contains('.'))
{
var navName = field[..field.IndexOf('.')];
navigationNames.Add(navName);
}
}

// Return only navigations that have abstract types
foreach (var navName in navigationNames)
{
// Find navigation in metadata (case-insensitive)
Navigation? navMetadata = null;
string? actualNavName = null;
foreach (var (key, value) in navigationProperties)
{
if (string.Equals(key, navName, StringComparison.OrdinalIgnoreCase))
{
navMetadata = value;
actualNavName = key;
break;
}
}

if (navMetadata == null || actualNavName == null)
{
continue;
}

// Only return abstract types - concrete types can use projection
if (navMetadata.Type.IsAbstract)
{
yield return navMetadata.Name;
}
}
}

public Task<bool> ShouldIncludeWithProjection(
object userContext,
Expand Down
64 changes: 26 additions & 38 deletions src/GraphQL.EntityFramework/Filters/Filters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,38 +76,36 @@ internal void Add<TEntity, TProjection>(

Dictionary<Type, IFilterEntry<TDbContext>> entries = [];

public IReadOnlySet<string> GetRequiredFilterProperties<TEntity>()
/// <summary>
/// Get all filters that apply to the specified entity type (including base type filters).
/// </summary>
internal IEnumerable<IFilterEntry<TDbContext>> GetFilters<TEntity>()
where TEntity : class
{
var result = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
var filterEntries = FindFilters<TEntity>();

foreach (var entry in filterEntries)
{
foreach (var prop in entry.RequiredPropertyNames)
{
result.Add(prop);
}
}

return result;
var type = typeof(TEntity);
return entries
.Where(_ => _.Key.IsAssignableFrom(type))
.Select(_ => _.Value);
}

public IReadOnlyDictionary<Type, IReadOnlySet<string>> GetAllRequiredFilterProperties()
{
var result = new Dictionary<Type, IReadOnlySet<string>>();
/// <summary>
/// Get all filters that apply to the specified entity type (including base type filters).
/// </summary>
internal IEnumerable<IFilterEntry<TDbContext>> GetFilters(Type entityType) =>
entries
.Where(_ => _.Key.IsAssignableFrom(entityType))
.Select(_ => _.Value);

foreach (var (entityType, entry) in entries)
{
var props = entry.RequiredPropertyNames;
if (props.Count > 0)
{
result[entityType] = props;
}
}
/// <summary>
/// Get all registered filter entries.
/// </summary>
internal IEnumerable<IFilterEntry<TDbContext>> GetAllFilters() =>
entries.Values;

return result;
}
/// <summary>
/// Returns true if there are any filters registered.
/// </summary>
internal bool HasFilters => entries.Count > 0;

internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(
IEnumerable<TEntity> result,
Expand All @@ -121,7 +119,7 @@ internal virtual async Task<IEnumerable<TEntity>> ApplyFilter<TEntity>(
return result;
}

var filterEntries = FindFilters<TEntity>().ToList();
var filterEntries = GetFilters<TEntity>().ToList();
if (filterEntries.Count == 0)
{
return result;
Expand Down Expand Up @@ -175,22 +173,12 @@ internal virtual async Task<bool> ShouldInclude<TEntity>(
return true;
}

var filterEntries = FindFilters<TEntity>().ToList();
var filterEntries = GetFilters<TEntity>().ToList();
if (filterEntries.Count == 0)
{
return true;
}

return await ShouldIncludeItem(userContext, data, userPrincipal, item, filterEntries);
}

List<IFilterEntry<TDbContext>> FindFilters<TEntity>()
where TEntity : class
{
var type = typeof(TEntity);
return entries
.Where(_ => _.Key.IsAssignableFrom(type))
.Select(_ => _.Value)
.ToList();
}
}
21 changes: 20 additions & 1 deletion src/GraphQL.EntityFramework/Filters/IFilterEntry.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,26 @@
interface IFilterEntry<TDbContext>
where TDbContext : DbContext
{
IReadOnlySet<string> RequiredPropertyNames { get; }
/// <summary>
/// Add this filter's requirements to the projection.
/// Returns the updated projection with filter-required fields and navigations merged in.
/// </summary>
/// <param name="projection">The current projection to merge requirements into.</param>
/// <param name="navigationProperties">Navigation property metadata for the entity type.</param>
/// <returns>Updated projection with filter requirements included.</returns>
FieldProjectionInfo AddRequirements(
FieldProjectionInfo projection,
IReadOnlyDictionary<string, Navigation>? navigationProperties);

/// <summary>
/// Get navigation names where the navigation type is abstract.
/// These navigations need Include() instead of projection because
/// abstract types cannot be instantiated in a Select expression.
/// </summary>
/// <param name="navigationProperties">Navigation property metadata for the entity type.</param>
/// <returns>Navigation names that require Include due to abstract types.</returns>
IEnumerable<string> GetAbstractNavigationIncludes(
IReadOnlyDictionary<string, Navigation>? navigationProperties);

Task<bool> ShouldIncludeWithProjection(
object userContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,7 @@ FieldType BuildFirstField<TSource, TReturn>(
query = query.AsNoTracking();
}

// Get filter-required fields early so we can add filter-required navigations via Include
var allFilterFields = fieldContext.Filters?.GetAllRequiredFilterProperties();

query = includeAppender.AddIncludesWithFiltersAndDetectNavigations(query, context, allFilterFields);
query = includeAppender.AddIncludesWithFiltersAndDetectNavigations(query, context, fieldContext.Filters);
query = query.ApplyGraphQlArguments(context, names, false, omitQueryArguments);

// Apply column projection based on requested GraphQL fields
Expand All @@ -176,7 +173,7 @@ FieldType BuildFirstField<TSource, TReturn>(
// Try to build projection even with abstract filter navigations
// The projection system may handle them (e.g., TPH inheritance)
// If projection build fails, we fall back to Include (which was already added above)
if (includeAppender.TryGetProjectionExpressionWithFilters<TReturn>(context, allFilterFields, out var selectExpr))
if (includeAppender.TryGetProjectionExpressionWithFilters<TDbContext, TReturn>(context, fieldContext.Filters, out var selectExpr))
{
query = query.Select(selectExpr);
}
Expand Down
Loading