Skip to content

Commit 28999d5

Browse files
authored
Fix nullability processing for aggregate PostgresFunctionExpression (#2396)
1 parent 596ef7e commit 28999d5

2 files changed

Lines changed: 51 additions & 27 deletions

File tree

src/EFCore.PG/Query/Expressions/Internal/PostgresFunctionExpression.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,9 @@ public override SqlFunctionExpression Update(SqlExpression? instance, IReadOnlyL
217217
: this;
218218
}
219219

220-
public virtual SqlFunctionExpression UpdateAggregateComponents(SqlExpression? predicate, IReadOnlyList<OrderingExpression> orderings)
220+
public virtual PostgresFunctionExpression UpdateAggregateComponents(
221+
SqlExpression? predicate,
222+
IReadOnlyList<OrderingExpression> orderings)
221223
{
222224
return predicate != AggregatePredicate || orderings != AggregateOrderings
223225
? new PostgresFunctionExpression(

src/EFCore.PG/Query/Internal/NpgsqlSqlNullabilityProcessor.cs

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ PostgresArrayIndexExpression arrayIndexExpression
3636
=> VisitArrayIndex(arrayIndexExpression, allowOptimizedExpansion, out nullable),
3737
PostgresBinaryExpression binaryExpression
3838
=> VisitBinary(binaryExpression, allowOptimizedExpansion, out nullable),
39-
PostgresFunctionExpression postgresFunctionExpression
40-
=> VisitPostgresFunction(postgresFunctionExpression, allowOptimizedExpansion, out nullable),
4139
PostgresILikeExpression ilikeExpression
4240
=> VisitILike(ilikeExpression, allowOptimizedExpansion, out nullable),
4341
PostgresJsonTraversalExpression postgresJsonTraversalExpression
@@ -51,6 +49,8 @@ PostgresRowValueExpression postgresRowValueExpression
5149
PostgresUnknownBinaryExpression postgresUnknownBinaryExpression
5250
=> VisitUnknownBinary(postgresUnknownBinaryExpression, allowOptimizedExpansion, out nullable),
5351

52+
// PostgresFunctionExpression is visited via the SqlFunctionExpression override below
53+
5454
_ => base.VisitCustomSqlExpression(sqlExpression, allowOptimizedExpansion, out nullable)
5555
};
5656

@@ -182,46 +182,68 @@ protected virtual SqlExpression VisitBinary(
182182
return binaryExpression.Update(left, right);
183183
}
184184

185-
protected virtual SqlExpression VisitPostgresFunction(
186-
PostgresFunctionExpression functionExpression,
185+
protected override SqlExpression VisitSqlFunction(
186+
SqlFunctionExpression sqlFunctionExpression,
187187
bool allowOptimizedExpansion,
188188
out bool nullable)
189189
{
190190
// PostgresFunctionExpression extends SqlFunctionExpression, and adds aggregate predicate and ordering expressions to that.
191191
// First call the base VisitSqlFunction to visit the arguments
192-
var visitedBase =
193-
(PostgresFunctionExpression)base.VisitSqlFunction(functionExpression, allowOptimizedExpansion, out nullable);
192+
var visitedBase = base.VisitSqlFunction(sqlFunctionExpression, allowOptimizedExpansion, out nullable);
193+
194+
// base.VisitSqlFunction has some special logic for SUM which wraps it in a COALESCE
195+
// (see https://github.com/dotnet/efcore/issues/28158), so we need some special handling to properly visit the
196+
// PostgresFunctionExpression it wraps.
197+
if (sqlFunctionExpression.IsBuiltIn
198+
&& string.Equals(sqlFunctionExpression.Name, "SUM", StringComparison.OrdinalIgnoreCase)
199+
&& visitedBase is SqlFunctionExpression { Name: "COALESCE", Arguments: { } } coalesceExpression
200+
&& coalesceExpression.Arguments[0] is PostgresFunctionExpression wrappedFunctionExpression)
201+
{
202+
var visitedArguments = coalesceExpression.Arguments!.ToArray();
203+
visitedArguments[0] = VisitPostgresFunctionComponents(wrappedFunctionExpression);
194204

195-
var aggregateChanged = false;
205+
return coalesceExpression.Update(coalesceExpression.Instance, visitedArguments);
206+
}
196207

197-
var visitedAggregatePredicate = Visit(functionExpression.AggregatePredicate, allowOptimizedExpansion: true, out _);
198-
aggregateChanged |= visitedAggregatePredicate != functionExpression.AggregatePredicate;
208+
return visitedBase is PostgresFunctionExpression pgFunctionExpression
209+
? VisitPostgresFunctionComponents(pgFunctionExpression)
210+
: visitedBase;
199211

200-
OrderingExpression[]? visitedOrderings = null;
201-
for (var i = 0; i < functionExpression.AggregateOrderings.Count; i++)
212+
PostgresFunctionExpression VisitPostgresFunctionComponents(PostgresFunctionExpression pgFunctionExpression)
202213
{
203-
var ordering = functionExpression.AggregateOrderings[i];
204-
var visitedOrdering = ordering.Update(Visit(ordering.Expression, out _));
205-
if (visitedOrdering != ordering && visitedOrderings is null)
214+
var aggregateChanged = false;
215+
216+
var visitedAggregatePredicate = Visit(pgFunctionExpression.AggregatePredicate, allowOptimizedExpansion: true, out _);
217+
aggregateChanged |= visitedAggregatePredicate != pgFunctionExpression.AggregatePredicate;
218+
219+
OrderingExpression[]? visitedOrderings = null;
220+
for (var i = 0; i < pgFunctionExpression.AggregateOrderings.Count; i++)
206221
{
207-
visitedOrderings = new OrderingExpression[functionExpression.AggregateOrderings.Count];
208-
for (var j = 0; j < i; j++)
222+
var ordering = pgFunctionExpression.AggregateOrderings[i];
223+
var visitedOrdering = ordering.Update(Visit(ordering.Expression, out _));
224+
if (visitedOrdering != ordering && visitedOrderings is null)
209225
{
210-
visitedOrderings[j] = functionExpression.AggregateOrderings[j];
226+
visitedOrderings = new OrderingExpression[pgFunctionExpression.AggregateOrderings.Count];
227+
for (var j = 0; j < i; j++)
228+
{
229+
visitedOrderings[j] = pgFunctionExpression.AggregateOrderings[j];
230+
}
231+
232+
aggregateChanged = true;
211233
}
212234

213-
aggregateChanged = true;
235+
if (visitedOrderings is not null)
236+
{
237+
visitedOrderings[i] = visitedOrdering;
238+
}
214239
}
215240

216-
if (visitedOrderings is not null)
217-
{
218-
visitedOrderings[i] = visitedOrdering;
219-
}
241+
return aggregateChanged
242+
? pgFunctionExpression.UpdateAggregateComponents(
243+
visitedAggregatePredicate,
244+
visitedOrderings ?? pgFunctionExpression.AggregateOrderings)
245+
: pgFunctionExpression;
220246
}
221-
222-
return aggregateChanged
223-
? visitedBase.UpdateAggregateComponents(visitedAggregatePredicate, visitedOrderings ?? functionExpression.AggregateOrderings)
224-
: visitedBase;
225247
}
226248

227249
protected virtual SqlExpression VisitILike(

0 commit comments

Comments
 (0)