diff --git a/core/src/main/java/org/apache/calcite/rel/rules/SetOpToFilterRule.java b/core/src/main/java/org/apache/calcite/rel/rules/SetOpToFilterRule.java index e00e3e50f685..93899ba9fefb 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/SetOpToFilterRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/SetOpToFilterRule.java @@ -18,12 +18,17 @@ import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelCollations; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Filter; import org.apache.calcite.rel.core.Intersect; import org.apache.calcite.rel.core.Minus; +import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.core.SetOp; +import org.apache.calcite.rel.core.Sort; import org.apache.calcite.rel.core.Union; +import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexUtil; import org.apache.calcite.tools.RelBuilder; @@ -117,13 +122,18 @@ protected SetOpToFilterRule(Config config) { private static void match(RelOptRuleCall call) { final SetOp setOp = call.rel(0); + final RelMetadataQuery mq = call.getMetadataQuery(); final List inputs = setOp.getInputs(); if (setOp.all || inputs.size() < 2) { return; } final RelBuilder builder = call.builder(); - Pair first = extractSourceAndCond(inputs.get(0).stripped()); + final RelNode firstClause = inputs.get(0).stripped(); + final List firstCollations = mq.collations(firstClause); + Pair first = + extractSourceAndCond(firstClause, firstCollations != null + && firstCollations.stream().anyMatch(c -> c != RelCollations.EMPTY)); // Groups conditions by their source relational node and input position. // - Key: Pair of (sourceRelNode, inputPosition) @@ -143,7 +153,14 @@ private static void match(RelOptRuleCall call) { for (int i = 1; i < inputs.size(); i++) { final RelNode input = inputs.get(i).stripped(); - final Pair pair = extractSourceAndCond(input); + boolean isSorted = false; + final List inputCollations = mq.collations(input); + if (inputCollations != null + && inputCollations.stream().anyMatch(c -> c != RelCollations.EMPTY) + && inputCollations.equals(firstCollations)) { + isSorted = true; + } + final Pair pair = extractSourceAndCond(input, isSorted); sourceToConds.computeIfAbsent(Pair.of(pair.left, pair.right != null ? null : i), k -> new ArrayList<>()).add(pair.right); } @@ -197,7 +214,8 @@ private static RelBuilder buildSetOp(RelBuilder builder, int count, RelNode setO throw new IllegalStateException("unreachable code"); } - private static Pair extractSourceAndCond(RelNode input) { + private static Pair extractSourceAndCond(RelNode input, + boolean isSorted) { if (input instanceof Filter) { Filter filter = (Filter) input; if (!RexUtil.isDeterministic(filter.getCondition()) @@ -205,13 +223,37 @@ private static RelBuilder buildSetOp(RelBuilder builder, int count, RelNode setO // Skip non-deterministic conditions or those containing subqueries return Pair.of(input, null); } - return Pair.of(filter.getInput().stripped(), filter.getCondition()); + final RelNode source = filter.getInput().stripped(); + if (containsBlockingSortInProjectFilterChain(source, isSorted)) { + return Pair.of(input, null); + } + return Pair.of(source, filter.getCondition()); + } + if (containsBlockingSortInProjectFilterChain(input, isSorted)) { + return Pair.of(input, null); } // For non-filter inputs, use TRUE literal as default condition. return Pair.of(input.stripped(), input.getCluster().getRexBuilder().makeLiteral(true)); } + private static boolean containsBlockingSortInProjectFilterChain(RelNode input, + boolean isSorted) { + RelNode current = input.stripped(); + while (true) { + if (current instanceof Sort) { + Sort sort = (Sort) current; + return !isSorted && (sort.fetch != null || sort.offset != null); + } + if (current instanceof Project + || current instanceof Filter) { + current = current.getInput(0).stripped(); + continue; + } + return false; + } + } + /** * Creates a combined condition where the first condition * is kept as-is and all subsequent conditions are negated, diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java index bd983836d216..7b664150a0e6 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java @@ -11435,6 +11435,119 @@ private void checkLoptOptimizeJoinRule(LoptOptimizeJoinRule rule) { .check(); } + /** Test case of + * [CALCITE-7463] + * UnionToFilterRule incorrectly rewrites UNION with LIMIT. */ + @Test void testUnionToFilterRuleWithLimit() { + final String sql = "(SELECT mgr, comm FROM emp LIMIT 2)\n" + + "UNION\n" + + "(SELECT mgr, comm FROM emp LIMIT 2)\n"; + sql(sql) + .withRule(CoreRules.UNION_FILTER_TO_FILTER) + .checkUnchanged(); + } + + /** Test case of + * [CALCITE-7463] + * UnionToFilterRule incorrectly rewrites UNION with LIMIT. */ + @Test void testUnionAllToFilterRuleWithLimit() { + final String sql = "(SELECT mgr, comm FROM emp LIMIT 2)\n" + + "UNION ALL\n" + + "(SELECT mgr, comm FROM emp LIMIT 2)\n"; + sql(sql) + .withRule(CoreRules.UNION_FILTER_TO_FILTER) + .checkUnchanged(); + } + + /** Test case of + * [CALCITE-7463] + * UnionToFilterRule incorrectly rewrites UNION with LIMIT. */ + @Test void testUnionToFilterRuleWithNestedLimit() { + final String sql = "SELECT comm FROM (SELECT mgr, comm FROM emp LIMIT 2) t\n" + + "WHERE comm > 5\n" + + "UNION\n" + + "SELECT comm FROM (SELECT mgr, comm FROM emp LIMIT 2) t\n" + + "WHERE comm > 10\n"; + sql(sql) + .withPreRule(CoreRules.PROJECT_FILTER_TRANSPOSE) + .withRule(CoreRules.UNION_FILTER_TO_FILTER) + .checkUnchanged(); + } + + /** Test case of + * [CALCITE-7463] + * UnionToFilterRule incorrectly rewrites UNION with LIMIT. */ + @Test void testUnionToFilterRuleWithSortOnly() { + final Function relFn = b -> { + final RelNode left = b.scan("EMP") + .project(b.field("MGR"), b.field("COMM")) + .sort(b.field(0)) + .filter(b.call(SqlStdOperatorTable.GREATER_THAN, b.field(1), b.literal(5))) + .build(); + final RelNode right = b.scan("EMP") + .project(b.field("MGR"), b.field("COMM")) + .sort(b.field(0)) + .filter(b.call(SqlStdOperatorTable.GREATER_THAN, b.field(1), b.literal(10))) + .build(); + return b.push(left) + .push(right) + .union(false, 2) + .build(); + }; + relFn(relFn) + .withRule(CoreRules.UNION_FILTER_TO_FILTER) + .check(); + } + + /** Test case of + * [CALCITE-7463] + * UnionToFilterRule incorrectly rewrites UNION with LIMIT. */ + @Test void testUnionToFilterRuleWithSortLimit() { + final Function relFn = b -> { + final RelNode left = b.scan("EMP") + .project(b.field("MGR"), b.field("COMM")) + .sortLimit(10, 2, b.field(0)) + .filter(b.call(SqlStdOperatorTable.GREATER_THAN, b.field(1), b.literal(5))) + .build(); + final RelNode right = b.scan("EMP") + .project(b.field("MGR"), b.field("COMM")) + .sortLimit(10, 2, b.field(0)) + .filter(b.call(SqlStdOperatorTable.GREATER_THAN, b.field(1), b.literal(10))) + .build(); + return b.push(left) + .push(right) + .union(false, 2) + .build(); + }; + relFn(relFn) + .withRule(CoreRules.UNION_FILTER_TO_FILTER) + .check(); + } + + /** Test case of + * [CALCITE-7463] + * UnionToFilterRule incorrectly rewrites UNION with LIMIT. */ + @Test void testUnionToFilterRuleWithUnmergeableFirstInput() { + final Function relFn = b -> { + final RelNode left = b.scan("EMP") + .project(b.field("MGR"), b.field("COMM")) + .sortLimit(10, 2, b.field(0)) + .filter(b.call(SqlStdOperatorTable.GREATER_THAN, b.field(1), b.literal(5))) + .build(); + final RelNode right = b.scan("EMP") + .project(b.field("MGR"), b.field("COMM")) + .filter(b.call(SqlStdOperatorTable.GREATER_THAN, b.field(1), b.literal(10))) + .build(); + return b.push(left) + .push(right) + .union(false, 2) + .build(); + }; + relFn(relFn) + .withRule(CoreRules.UNION_FILTER_TO_FILTER) + .checkUnchanged(); + } + /** Test case of * [CALCITE-7002] * Create an optimization rule to eliminate UNION diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index ece2ad5258aa..d332f15fd65c 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -21336,6 +21336,25 @@ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$ LogicalFilter(condition=[AND(<(+($0, 50), 20), >=($cor0.DEPTNO, $9))]) LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], $f9=[+(30, $7)]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + + + + + + @@ -21495,6 +21514,50 @@ LogicalUnion(all=[false]) LogicalFilter(condition=[SEARCH($0, Sarg[5, 10])]) LogicalProject(DEPTNO=[$0]) LogicalTableScan(table=[[CATALOG, SALES, DEPT]]) +]]> + + + + + + + + + + + + + 5 +UNION +SELECT comm FROM (SELECT mgr, comm FROM emp LIMIT 2) t +WHERE comm > 10 +]]> + + + ($0, 5)]) + LogicalProject(COMM=[$1]) + LogicalSort(fetch=[2]) + LogicalProject(MGR=[$3], COMM=[$6]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalFilter(condition=[>($0, 10)]) + LogicalProject(COMM=[$1]) + LogicalSort(fetch=[2]) + LogicalProject(MGR=[$3], COMM=[$6]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) ]]> @@ -21541,6 +21604,54 @@ LogicalUnion(all=[false]) LogicalAggregate(group=[{0, 1}]) LogicalProject(MGR=[$3], COMM=[$6]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + + + ($1, 5)]) + LogicalSort(sort0=[$0], dir0=[ASC], offset=[10], fetch=[2]) + LogicalProject(MGR=[$3], COMM=[$6]) + LogicalTableScan(table=[[scott, EMP]]) + LogicalFilter(condition=[>($1, 10)]) + LogicalSort(sort0=[$0], dir0=[ASC], offset=[10], fetch=[2]) + LogicalProject(MGR=[$3], COMM=[$6]) + LogicalTableScan(table=[[scott, EMP]]) +]]> + + + ($1, 5)]) + LogicalSort(sort0=[$0], dir0=[ASC], offset=[10], fetch=[2]) + LogicalProject(MGR=[$3], COMM=[$6]) + LogicalTableScan(table=[[scott, EMP]]) +]]> + + + + + ($1, 5)]) + LogicalSort(sort0=[$0], dir0=[ASC]) + LogicalProject(MGR=[$3], COMM=[$6]) + LogicalTableScan(table=[[scott, EMP]]) + LogicalFilter(condition=[>($1, 10)]) + LogicalSort(sort0=[$0], dir0=[ASC]) + LogicalProject(MGR=[$3], COMM=[$6]) + LogicalTableScan(table=[[scott, EMP]]) +]]> + + + ($1, 5)]) + LogicalSort(sort0=[$0], dir0=[ASC]) + LogicalProject(MGR=[$3], COMM=[$6]) + LogicalTableScan(table=[[scott, EMP]]) ]]> @@ -21599,6 +21710,20 @@ LogicalAggregate(group=[{0, 1}]) LogicalFilter(condition=[OR(=($0, 12), =($1, 5))]) LogicalProject(MGR=[$3], COMM=[$6]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + + + ($1, 5)]) + LogicalSort(sort0=[$0], dir0=[ASC], offset=[10], fetch=[2]) + LogicalProject(MGR=[$3], COMM=[$6]) + LogicalTableScan(table=[[scott, EMP]]) + LogicalFilter(condition=[>($1, 10)]) + LogicalProject(MGR=[$3], COMM=[$6]) + LogicalTableScan(table=[[scott, EMP]]) ]]>