diff --git a/core/src/main/java/org/apache/calcite/rel/rules/MultiJoinOptimizeBushyRule.java b/core/src/main/java/org/apache/calcite/rel/rules/MultiJoinOptimizeBushyRule.java index d636d23fab9..6336b997e74 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/MultiJoinOptimizeBushyRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/MultiJoinOptimizeBushyRule.java @@ -65,13 +65,9 @@ * {@code LoptOptimizeJoinRule} is only capable of producing left-deep joins; * this rule is capable of producing bushy joins. * - *

<p>TODO: - *
<ol> - *
  <li>Join conditions that touch 1 factor. - *
  <li>Join conditions that touch 3 factors. - *
  <li>More than 1 join conditions that touch the same pair of factors, - * e.g. {@code t0.c1 = t1.c1 and t1.c2 = t0.c3} - *
</ol> + *
<p>
TODO: Join conditions that touch exactly 1 factor are currently applied + * as a filter above the join tree rather than being pushed down to the + * individual table scan. * * @see CoreRules#MULTI_JOIN_OPTIMIZE_BUSHY */ @@ -130,9 +126,17 @@ public MultiJoinOptimizeBushyRule(RelFactories.JoinFactory joinFactory, } assert x == multiJoin.getNumTotalFields(); + final List remainingConditions = new ArrayList<>(); final List unusedEdges = new ArrayList<>(); for (RexNode node : multiJoin.getJoinFilters()) { - unusedEdges.add(multiJoin.createEdge(node)); + LoptMultiJoin.Edge edge = multiJoin.createEdge(node); + if (edge.factors.cardinality() == 2) { + unusedEdges.add(edge); + } else { + // Conditions touching 1 or 3+ factors cannot be used as binary join + // edges. Re-apply them as a filter above the finished join tree. + remainingConditions.add(node); + } } // Comparator that chooses the best edge. A "good edge" is one that has @@ -170,11 +174,8 @@ private double rowCountDiff(LoptMultiJoin.Edge edge) { } else { final LoptMultiJoin.Edge bestEdge = unusedEdges.get(edgeOrdinal); - // For now, assume that the edge is between precisely two factors. - // 1-factor conditions have probably been pushed down, - // and 3-or-more-factor conditions are advanced. (TODO:) - // Therefore, for now, the factors that are merged are exactly the - // factors on this edge. + // Each edge in unusedEdges touches exactly two factors; conditions + // touching 1 or 3+ factors were separated out before the greedy loop. 
assert bestEdge.factors.cardinality() == 2; factors = bestEdge.factors.toArray(); } @@ -299,8 +300,17 @@ private double rowCountDiff(LoptMultiJoin.Edge edge) { } final Pair top = Util.last(relNodes); - relBuilder.push(top.left) - .project(relBuilder.fields(top.right)); + relBuilder.push(top.left); + if (!remainingConditions.isEmpty()) { + final RexVisitor shuttle = + new RexPermuteInputsShuttle(top.right, top.left); + final List remapped = new ArrayList<>(); + for (RexNode c : remainingConditions) { + remapped.add(c.accept(shuttle)); + } + relBuilder.filter(remapped); + } + relBuilder.project(relBuilder.fields(top.right)); call.transformTo(relBuilder.build()); } diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java index 0f5959fb574..715d9f4f7b8 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java @@ -4344,6 +4344,28 @@ private void checkPushJoinThroughUnionOnRightDoesNotMatchSemiOrAntiJoin(JoinRelT .checkUnchanged(); } + /** Test case for the TODO in {@link MultiJoinOptimizeBushyRule}: + * "Join conditions that touch 3 factors." + * + *

The CASE condition references three factors (e1, d, e2) and therefore + * cannot be represented as a binary join edge. The rule should handle it + * gracefully rather than throwing an {@code AssertionError}. */ + @Test void testMultiJoinOptimizeBushyThreeFactorCondition() { + HepProgram preProgram = new HepProgramBuilder() + .addRuleInstance(CoreRules.FILTER_INTO_JOIN) + .addMatchOrder(HepMatchOrder.BOTTOM_UP) + .addRuleInstance(CoreRules.JOIN_TO_MULTI_JOIN) + .build(); + HepProgram program = new HepProgramBuilder() + .addMatchOrder(HepMatchOrder.BOTTOM_UP) + .addRuleInstance(CoreRules.MULTI_JOIN_OPTIMIZE_BUSHY) + .build(); + final String sql = "select e1.ename from emp e1, dept d, emp e2\n" + + "where e1.deptno = d.deptno and e2.deptno = d.deptno\n" + + "and d.deptno = case when e1.sal > 1000 then e2.empno else e1.empno end"; + sql(sql).withPre(preProgram).withProgram(program).check(); + } + @Test void testConvertMultiJoinRule() { final String sql = "select e1.ename from emp e1, dept d, emp e2\n" + "where e1.deptno = d.deptno and d.deptno = e2.deptno"; diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index fc60b886497..c782bb913cc 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -10775,6 +10775,34 @@ LogicalAggregate(group=[{0, 1}]) LogicalFilter(condition=[AND(=($0, 12), <>($1, 5))]) LogicalProject(MGR=[$3], COMM=[$6]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + + + 1000 then e2.empno else e1.empno end]]> + + + ($5, 1000), $11, $0)), =($7, $9))], isFullOuterJoin=[false], joinTypes=[[INNER, INNER, INNER]], outerJoinConditions=[[NULL, NULL, NULL]], projFields=[[ALL, ALL, ALL]]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalTableScan(table=[[CATALOG, SALES, DEPT]]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + ($16, 
1000), $2, $11))]) + LogicalJoin(condition=[=($18, $0)], joinType=[inner]) + LogicalJoin(condition=[=($9, $0)], joinType=[inner]) + LogicalTableScan(table=[[CATALOG, SALES, DEPT]]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) ]]>