From a3b97fa1101906bc6b98e5474cd0092d09a33e23 Mon Sep 17 00:00:00 2001 From: Sean Broeder Date: Thu, 14 May 2026 07:42:20 -0700 Subject: [PATCH] [CALCITE-7514] MultiJoinOptimizeBushyRule throws AssertionError when a join condition references 3 or more factors Conditions in a MultiJoin's joinFilters that reference anything other than exactly two factors cannot be represented as binary join edges. Passing such a condition to createEdge produced an edge with factors.cardinality() != 2, causing an AssertionError in the edge comparator's rowCountDiff method, and at two further assertion sites in the greedy loop. The fix separates these conditions from the edge list upfront. After the greedy join-ordering loop completes, the remaining conditions are remapped from original MultiJoin field positions to the final join tree's output positions via RexPermuteInputsShuttle, then applied as a LogicalFilter above the join tree before the reordering project. For inner joins this is semantically equivalent to applying them as join predicates. Two TODO items are resolved: - "Join conditions that touch 3 factors" is fully handled. - "More than 1 join conditions that touch the same pair of factors" was stale from the original commit; the conditions loop already collects all edges subsumed by newFactors at each greedy step. A remaining TODO notes that 1-factor conditions are applied as a filter above the join tree rather than pushed down to the individual scan. 
--- .../rel/rules/MultiJoinOptimizeBushyRule.java | 40 ++++++++++++------- .../apache/calcite/test/RelOptRulesTest.java | 22 ++++++++++ .../apache/calcite/test/RelOptRulesTest.xml | 28 +++++++++++++ 3 files changed, 75 insertions(+), 15 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/rel/rules/MultiJoinOptimizeBushyRule.java b/core/src/main/java/org/apache/calcite/rel/rules/MultiJoinOptimizeBushyRule.java index d636d23fab92..6336b997e742 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/MultiJoinOptimizeBushyRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/MultiJoinOptimizeBushyRule.java @@ -65,13 +65,9 @@ * {@code LoptOptimizeJoinRule} is only capable of producing left-deep joins; * this rule is capable of producing bushy joins. * - *

<p>TODO: - *
<ol> - *
<li>Join conditions that touch 1 factor. - *
<li>Join conditions that touch 3 factors. - *
<li>More than 1 join conditions that touch the same pair of factors, - * e.g. {@code t0.c1 = t1.c1 and t1.c2 = t0.c3} - *
</ol> + * <p>
TODO: Join conditions that touch exactly 1 factor are currently applied + * as a filter above the join tree rather than being pushed down to the + * individual table scan. * * @see CoreRules#MULTI_JOIN_OPTIMIZE_BUSHY */ @@ -130,9 +126,17 @@ public MultiJoinOptimizeBushyRule(RelFactories.JoinFactory joinFactory, } assert x == multiJoin.getNumTotalFields(); + final List remainingConditions = new ArrayList<>(); final List unusedEdges = new ArrayList<>(); for (RexNode node : multiJoin.getJoinFilters()) { - unusedEdges.add(multiJoin.createEdge(node)); + LoptMultiJoin.Edge edge = multiJoin.createEdge(node); + if (edge.factors.cardinality() == 2) { + unusedEdges.add(edge); + } else { + // Conditions touching 1 or 3+ factors cannot be used as binary join + // edges. Re-apply them as a filter above the finished join tree. + remainingConditions.add(node); + } } // Comparator that chooses the best edge. A "good edge" is one that has @@ -170,11 +174,8 @@ private double rowCountDiff(LoptMultiJoin.Edge edge) { } else { final LoptMultiJoin.Edge bestEdge = unusedEdges.get(edgeOrdinal); - // For now, assume that the edge is between precisely two factors. - // 1-factor conditions have probably been pushed down, - // and 3-or-more-factor conditions are advanced. (TODO:) - // Therefore, for now, the factors that are merged are exactly the - // factors on this edge. + // Each edge in unusedEdges touches exactly two factors; conditions + // touching 1 or 3+ factors were separated out before the greedy loop. 
assert bestEdge.factors.cardinality() == 2; factors = bestEdge.factors.toArray(); } @@ -299,8 +300,17 @@ private double rowCountDiff(LoptMultiJoin.Edge edge) { } final Pair top = Util.last(relNodes); - relBuilder.push(top.left) - .project(relBuilder.fields(top.right)); + relBuilder.push(top.left); + if (!remainingConditions.isEmpty()) { + final RexVisitor shuttle = + new RexPermuteInputsShuttle(top.right, top.left); + final List remapped = new ArrayList<>(); + for (RexNode c : remainingConditions) { + remapped.add(c.accept(shuttle)); + } + relBuilder.filter(remapped); + } + relBuilder.project(relBuilder.fields(top.right)); call.transformTo(relBuilder.build()); } diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java index 0f5959fb574b..715d9f4f7b8e 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java @@ -4344,6 +4344,28 @@ private void checkPushJoinThroughUnionOnRightDoesNotMatchSemiOrAntiJoin(JoinRelT .checkUnchanged(); } + /** Test case for [CALCITE-7514], formerly a TODO in + * {@link MultiJoinOptimizeBushyRule}: join conditions that touch 3 factors. + * + *

The CASE condition references three factors (e1, d, e2) and therefore + * cannot be represented as a binary join edge. The rule should handle it + * gracefully rather than throwing an {@code AssertionError}. */ + @Test void testMultiJoinOptimizeBushyThreeFactorCondition() { + HepProgram preProgram = new HepProgramBuilder() + .addRuleInstance(CoreRules.FILTER_INTO_JOIN) + .addMatchOrder(HepMatchOrder.BOTTOM_UP) + .addRuleInstance(CoreRules.JOIN_TO_MULTI_JOIN) + .build(); + HepProgram program = new HepProgramBuilder() + .addMatchOrder(HepMatchOrder.BOTTOM_UP) + .addRuleInstance(CoreRules.MULTI_JOIN_OPTIMIZE_BUSHY) + .build(); + final String sql = "select e1.ename from emp e1, dept d, emp e2\n" + + "where e1.deptno = d.deptno and e2.deptno = d.deptno\n" + + "and d.deptno = case when e1.sal > 1000 then e2.empno else e1.empno end"; + sql(sql).withPre(preProgram).withProgram(program).check(); + } + @Test void testConvertMultiJoinRule() { final String sql = "select e1.ename from emp e1, dept d, emp e2\n" + "where e1.deptno = d.deptno and d.deptno = e2.deptno"; diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index fc60b8864971..c782bb913cc1 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -10775,6 +10775,34 @@ LogicalAggregate(group=[{0, 1}]) LogicalFilter(condition=[AND(=($0, 12), <>($1, 5))]) LogicalProject(MGR=[$3], COMM=[$6]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + + + 1000 then e2.empno else e1.empno end]]> + + + ($5, 1000), $11, $0)), =($7, $9))], isFullOuterJoin=[false], joinTypes=[[INNER, INNER, INNER]], outerJoinConditions=[[NULL, NULL, NULL]], projFields=[[ALL, ALL, ALL]]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalTableScan(table=[[CATALOG, SALES, DEPT]]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + 
($16, 1000), $2, $11))]) + LogicalJoin(condition=[=($18, $0)], joinType=[inner]) + LogicalJoin(condition=[=($9, $0)], joinType=[inner]) + LogicalTableScan(table=[[CATALOG, SALES, DEPT]]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) ]]>