Skip to content

Commit 1016470

Browse files
committed
Infrastructure for custom aggregate translation
Takes care of the built-in aggregates only for now (Min/Max/etc.) Part of #727
1 parent 1249c0d commit 1016470

14 files changed

Lines changed: 490 additions & 153 deletions

src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlDbFunctionsExtensions.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-

1+
using System.Runtime.CompilerServices;
22

33
// ReSharper disable once CheckNamespace
4-
5-
using System.Runtime.CompilerServices;
6-
74
namespace Microsoft.EntityFrameworkCore;
85

96
/// <summary>

src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlMultirangeDbFunctionsExtensions.cs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
// ReSharper disable once CheckNamespace
2-
3-
4-
51
// ReSharper disable once CheckNamespace
62
namespace Microsoft.EntityFrameworkCore;
73

src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlRangeDbFunctionsExtensions.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-

2-
3-
// ReSharper disable once CheckNamespace
1+
// ReSharper disable once CheckNamespace
42
namespace Microsoft.EntityFrameworkCore;
53

64
/// <summary>

src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlTrigramsDbFunctionsExtensions.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-

2-
3-
// ReSharper disable once CheckNamespace
1+
// ReSharper disable once CheckNamespace
42
namespace Microsoft.EntityFrameworkCore;
53

64
public static class NpgsqlTrigramsDbFunctionsExtensions

src/EFCore.PG/Extensions/NpgsqlServiceCollectionExtensions.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ public static IServiceCollection AddEntityFrameworkNpgsql(this IServiceCollectio
106106
.TryAdd<ICompiledQueryCacheKeyGenerator, NpgsqlCompiledQueryCacheKeyGenerator>()
107107
.TryAdd<IExecutionStrategyFactory, NpgsqlExecutionStrategyFactory>()
108108
.TryAdd<IMethodCallTranslatorProvider, NpgsqlMethodCallTranslatorProvider>()
109+
.TryAdd<IAggregateMethodCallTranslatorProvider, NpgsqlAggregateMethodCallTranslatorProvider>()
109110
.TryAdd<IMemberTranslatorProvider, NpgsqlMemberTranslatorProvider>()
110111
.TryAdd<IEvaluatableExpressionFilter, NpgsqlEvaluatableExpressionFilter>()
111112
.TryAdd<IQuerySqlGeneratorFactory, NpgsqlQuerySqlGeneratorFactory>()
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal;
2+
3+
public class NpgsqlAggregateMethodCallTranslatorProvider : RelationalAggregateMethodCallTranslatorProvider
4+
{
5+
public NpgsqlAggregateMethodCallTranslatorProvider(RelationalAggregateMethodCallTranslatorProviderDependencies dependencies)
6+
: base(dependencies)
7+
{
8+
var sqlExpressionFactory = (NpgsqlSqlExpressionFactory)dependencies.SqlExpressionFactory;
9+
var typeMappingSource = dependencies.RelationalTypeMappingSource;
10+
11+
AddTranslators(
12+
new IAggregateMethodCallTranslator[]
13+
{
14+
new NpgsqlQueryableAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource)
15+
});
16+
}
17+
}
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
using static Npgsql.EntityFrameworkCore.PostgreSQL.Utilities.Statics;
2+
3+
namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal;
4+
5+
public class NpgsqlQueryableAggregateMethodTranslator : IAggregateMethodCallTranslator
6+
{
7+
private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory;
8+
private readonly IRelationalTypeMappingSource _typeMappingSource;
9+
10+
public NpgsqlQueryableAggregateMethodTranslator(
11+
NpgsqlSqlExpressionFactory sqlExpressionFactory,
12+
IRelationalTypeMappingSource typeMappingSource)
13+
{
14+
_sqlExpressionFactory = sqlExpressionFactory;
15+
_typeMappingSource = typeMappingSource;
16+
}
17+
18+
public virtual SqlExpression? Translate(
19+
MethodInfo method,
20+
EnumerableExpression source,
21+
IReadOnlyList<SqlExpression> arguments,
22+
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
23+
{
24+
if (method.DeclaringType == typeof(Queryable))
25+
{
26+
var methodInfo = method.IsGenericMethod
27+
? method.GetGenericMethodDefinition()
28+
: method;
29+
switch (methodInfo.Name)
30+
{
31+
case nameof(Queryable.Average)
32+
when (QueryableMethods.IsAverageWithoutSelector(methodInfo)
33+
|| QueryableMethods.IsAverageWithSelector(methodInfo))
34+
&& source.Selector is SqlExpression averageSqlExpression:
35+
var averageInputType = averageSqlExpression.Type;
36+
if (averageInputType == typeof(int)
37+
|| averageInputType == typeof(long))
38+
{
39+
averageSqlExpression = _sqlExpressionFactory.ApplyDefaultTypeMapping(
40+
_sqlExpressionFactory.Convert(averageSqlExpression, typeof(double)));
41+
}
42+
43+
return averageInputType == typeof(float)
44+
? _sqlExpressionFactory.Convert(
45+
_sqlExpressionFactory.AggregateFunction(
46+
"AVG",
47+
new[] { averageSqlExpression },
48+
nullable: true,
49+
argumentsPropagateNullability: FalseArrays[1],
50+
source,
51+
typeof(double)),
52+
averageSqlExpression.Type,
53+
averageSqlExpression.TypeMapping)
54+
: _sqlExpressionFactory.AggregateFunction(
55+
"AVG",
56+
new[] { averageSqlExpression },
57+
nullable: true,
58+
argumentsPropagateNullability: FalseArrays[1],
59+
source,
60+
averageSqlExpression.Type,
61+
averageSqlExpression.TypeMapping);
62+
63+
// PostgreSQL COUNT() always returns bigint, so we need to downcast to int
64+
case nameof(Queryable.Count)
65+
when methodInfo == QueryableMethods.CountWithoutPredicate
66+
|| methodInfo == QueryableMethods.CountWithPredicate:
67+
var countSqlExpression = (source.Selector as SqlExpression) ?? _sqlExpressionFactory.Fragment("*");
68+
return _sqlExpressionFactory.Convert(
69+
_sqlExpressionFactory.ApplyDefaultTypeMapping(
70+
_sqlExpressionFactory.AggregateFunction(
71+
"COUNT",
72+
new[] { countSqlExpression },
73+
nullable: false,
74+
argumentsPropagateNullability: FalseArrays[1],
75+
source,
76+
typeof(long))),
77+
typeof(int), _typeMappingSource.FindMapping(typeof(int)));
78+
79+
case nameof(Queryable.LongCount)
80+
when methodInfo == QueryableMethods.LongCountWithoutPredicate
81+
|| methodInfo == QueryableMethods.LongCountWithPredicate:
82+
var longCountSqlExpression = (source.Selector as SqlExpression) ?? _sqlExpressionFactory.Fragment("*");
83+
return _sqlExpressionFactory.ApplyDefaultTypeMapping(
84+
_sqlExpressionFactory.AggregateFunction(
85+
"COUNT",
86+
new[] { longCountSqlExpression },
87+
nullable: false,
88+
argumentsPropagateNullability: FalseArrays[1],
89+
source,
90+
typeof(long)));
91+
92+
case nameof(Queryable.Max)
93+
when (methodInfo == QueryableMethods.MaxWithoutSelector
94+
|| methodInfo == QueryableMethods.MaxWithSelector)
95+
&& source.Selector is SqlExpression maxSqlExpression:
96+
return _sqlExpressionFactory.AggregateFunction(
97+
"MAX",
98+
new[] { maxSqlExpression },
99+
nullable: true,
100+
argumentsPropagateNullability: FalseArrays[1],
101+
source,
102+
maxSqlExpression.Type,
103+
maxSqlExpression.TypeMapping);
104+
105+
case nameof(Queryable.Min)
106+
when (methodInfo == QueryableMethods.MinWithoutSelector
107+
|| methodInfo == QueryableMethods.MinWithSelector)
108+
&& source.Selector is SqlExpression minSqlExpression:
109+
return _sqlExpressionFactory.AggregateFunction(
110+
"MIN",
111+
new[] { minSqlExpression },
112+
nullable: true,
113+
argumentsPropagateNullability: FalseArrays[1],
114+
source,
115+
minSqlExpression.Type,
116+
minSqlExpression.TypeMapping);
117+
118+
// In PostgreSQL SUM() doesn't return the same type as its argument for smallint, int and bigint.
119+
// Cast to get the same type.
120+
// http://www.postgresql.org/docs/current/static/functions-aggregate.html
121+
case nameof(Queryable.Sum)
122+
when (QueryableMethods.IsSumWithoutSelector(methodInfo)
123+
|| QueryableMethods.IsSumWithSelector(methodInfo))
124+
&& source.Selector is SqlExpression sumSqlExpression:
125+
var sumInputType = sumSqlExpression.Type;
126+
127+
// Note that there is no Sum over short in LINQ
128+
if (sumInputType == typeof(int))
129+
{
130+
return _sqlExpressionFactory.Convert(
131+
_sqlExpressionFactory.AggregateFunction(
132+
"SUM",
133+
new[] { sumSqlExpression },
134+
nullable: true,
135+
argumentsPropagateNullability: FalseArrays[1],
136+
source,
137+
typeof(long)),
138+
sumInputType,
139+
sumSqlExpression.TypeMapping);
140+
}
141+
142+
if (sumInputType == typeof(long))
143+
{
144+
return _sqlExpressionFactory.Convert(
145+
_sqlExpressionFactory.AggregateFunction(
146+
"SUM",
147+
new[] { sumSqlExpression },
148+
nullable: true,
149+
argumentsPropagateNullability: FalseArrays[1],
150+
source,
151+
typeof(decimal)),
152+
sumInputType,
153+
sumSqlExpression.TypeMapping);
154+
}
155+
156+
return _sqlExpressionFactory.AggregateFunction(
157+
"SUM",
158+
new[] { sumSqlExpression },
159+
nullable: true,
160+
argumentsPropagateNullability: FalseArrays[1],
161+
source,
162+
sumInputType,
163+
sumSqlExpression.TypeMapping);
164+
}
165+
}
166+
167+
return null;
168+
}
169+
}

0 commit comments

Comments
 (0)