Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add multi-provider support to SqlExpressionAttribute via Configuratio…
…n property

Co-authored-by: PhenX <42170+PhenX@users.noreply.github.com>
  • Loading branch information
Copilot and PhenX committed Mar 14, 2026
commit 1a5f62ed211250f5479321548e714a19790ffe24
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ public class SqlExpressionMethodCallTranslator : IMethodCallTranslator
new Regex(@"\{(\d+)\}", RegexOptions.Compiled);

private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly string? _providerName;

public SqlExpressionMethodCallTranslator(ISqlExpressionFactory sqlExpressionFactory)
public SqlExpressionMethodCallTranslator(ISqlExpressionFactory sqlExpressionFactory, string? providerName = null)
{
_sqlExpressionFactory = sqlExpressionFactory;
_providerName = providerName;
}

/// <inheritdoc />
Expand All @@ -38,11 +40,26 @@ public SqlExpressionMethodCallTranslator(ISqlExpressionFactory sqlExpressionFact
IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
var sqlExpressionAttr = method.GetCustomAttribute<SqlExpressionAttribute>();
if (sqlExpressionAttr is null)
var sqlExpressionAttrs = method.GetCustomAttributes<SqlExpressionAttribute>().ToArray();
if (sqlExpressionAttrs.Length == 0)
return null;

return TranslateTemplate(sqlExpressionAttr.Sql, arguments, method.ReturnType);
// Prefer an attribute whose Configuration matches the current provider name.
SqlExpressionAttribute? selectedAttr = null;
if (_providerName != null)
{
selectedAttr = sqlExpressionAttrs.FirstOrDefault(a =>
a.Configuration != null &&
_providerName.Contains(a.Configuration, StringComparison.OrdinalIgnoreCase));
}

// Fall back to an attribute without a Configuration (provider-agnostic).
selectedAttr ??= sqlExpressionAttrs.FirstOrDefault(a => a.Configuration is null);

if (selectedAttr is null)
return null;

return TranslateTemplate(selectedAttr.Sql, arguments, method.ReturnType);
}

private SqlExpression? TranslateTemplate(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections.Generic;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Query;

namespace EntityFrameworkCore.Projectables.Infrastructure.Internal
Expand All @@ -10,11 +11,12 @@ namespace EntityFrameworkCore.Projectables.Infrastructure.Internal
/// </summary>
public class SqlExpressionMethodCallTranslatorPlugin : IMethodCallTranslatorPlugin
{
public SqlExpressionMethodCallTranslatorPlugin(ISqlExpressionFactory sqlExpressionFactory)
public SqlExpressionMethodCallTranslatorPlugin(ISqlExpressionFactory sqlExpressionFactory, ICurrentDbContext currentDbContext)
{
var providerName = currentDbContext.Context.Database.ProviderName;
Translators = new IMethodCallTranslator[]
{
new SqlExpressionMethodCallTranslator(sqlExpressionFactory)
new SqlExpressionMethodCallTranslator(sqlExpressionFactory, providerName)
};
}

Expand Down
16 changes: 15 additions & 1 deletion src/EntityFrameworkCore.Projectables/SqlExpressionAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ namespace EntityFrameworkCore.Projectables
/// Decorates a static method with a SQL template string that will be used to translate
/// the method call into a SQL expression when used in a LINQ query against EF Core.
/// Use positional placeholders {0}, {1}, etc. to refer to the method arguments.
/// Multiple instances of this attribute may be applied to the same method, each with a
/// different <see cref="Configuration"/> value, to provide provider-specific SQL expressions.
/// </summary>
/// <example>
/// <code>
Expand All @@ -14,9 +16,14 @@ namespace EntityFrameworkCore.Projectables
///
/// [SqlExpression("COALESCE({0}, {1})")]
/// public static string Coalesce(string value, string fallback) => throw new NotImplementedException();
///
/// [SqlExpression("STRFTIME('%Y', {0})", Configuration = "Sqlite")]
/// [SqlExpression("YEAR({0})", Configuration = "SqlServer")]
/// [SqlExpression("EXTRACT(YEAR FROM {0})", Configuration = "Npgsql")]
/// public static int Year(DateTime date) => throw new NotImplementedException();
/// </code>
/// </example>
[AttributeUsage(AttributeTargets.Method, AllowMultiple = false)]
[AttributeUsage(AttributeTargets.Method, AllowMultiple = true)]
public sealed class SqlExpressionAttribute : Attribute
{
/// <summary>
Expand All @@ -40,5 +47,12 @@ public SqlExpressionAttribute(string sql)
/// throw <see cref="NotImplementedException"/> in its body.
/// </summary>
public bool ServerSideOnly { get; set; } = true;

/// <summary>
/// When set, this attribute only applies when the database provider name contains this value
/// (e.g. <c>"SqlServer"</c>, <c>"Sqlite"</c>, <c>"Npgsql"</c>).
/// When <c>null</c> (the default), the attribute acts as a fallback for any provider.
/// </summary>
public string? Configuration { get; set; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT YEAR([d].[CreatedAt])
FROM [DateEntity] AS [d]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT YEAR([d].[CreatedAt])
FROM [DateEntity] AS [d]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT YEAR([d].[CreatedAt])
FROM [DateEntity] AS [d]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT YEAR([d].[CreatedAt])
FROM [DateEntity] AS [d]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT YEAR([d].[CreatedAt])
FROM [DateEntity] AS [d]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT YEAR([d].[CreatedAt])
FROM [DateEntity] AS [d]
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ public static class Functions

[SqlExpression("COALESCE({0}, {1})")]
public static string Coalesce(string? value, string? fallback) => throw new NotImplementedException();

[SqlExpression("STRFTIME('%Y', {0})", Configuration = "Sqlite")]
[SqlExpression("YEAR({0})", Configuration = "SqlServer")]
[SqlExpression("EXTRACT(YEAR FROM {0})", Configuration = "Npgsql")]
public static int Year(DateTime date) => throw new NotImplementedException();

[SqlExpression("GENERIC_YEAR({0})")]
[SqlExpression("YEAR({0})", Configuration = "SqlServer")]
public static int YearWithFallback(DateTime date) => throw new NotImplementedException();
}

public record Entity
Expand All @@ -27,6 +36,12 @@ public record Entity
public string? NickName { get; set; }
}

public record DateEntity
{
public int Id { get; set; }
public DateTime CreatedAt { get; set; }
}

[Fact]
public Task WhereWithSqlExpressionUpper()
{
Expand All @@ -48,5 +63,27 @@ public Task SelectWithSqlExpressionCoalesce()

return Verifier.Verify(query.ToQueryString());
}

[Fact]
public Task SelectWithProviderSpecificSqlExpression()
{
using var dbContext = new SampleDbContext<DateEntity>();

var query = dbContext.Set<DateEntity>()
.Select(x => Functions.Year(x.CreatedAt));

return Verifier.Verify(query.ToQueryString());
}

[Fact]
public Task SelectWithFallbackSqlExpression()
{
using var dbContext = new SampleDbContext<DateEntity>();

var query = dbContext.Set<DateEntity>()
.Select(x => Functions.YearWithFallback(x.CreatedAt));

return Verifier.Verify(query.ToQueryString());
}
}
}