diff --git a/core/src/main/java/org/apache/calcite/rel/rules/MultiJoinOptimizeBushyRule.java b/core/src/main/java/org/apache/calcite/rel/rules/MultiJoinOptimizeBushyRule.java
index d636d23fab9..6336b997e74 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/MultiJoinOptimizeBushyRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/MultiJoinOptimizeBushyRule.java
@@ -65,13 +65,9 @@
* {@code LoptOptimizeJoinRule} is only capable of producing left-deep joins;
* this rule is capable of producing bushy joins.
*
- *
TODO:
- *
- * - Join conditions that touch 1 factor.
- *
- Join conditions that touch 3 factors.
- *
- More than 1 join conditions that touch the same pair of factors,
- * e.g. {@code t0.c1 = t1.c1 and t1.c2 = t0.c3}
- *
+ * TODO: Join conditions that touch exactly 1 factor are currently applied
+ * as a filter above the join tree rather than being pushed down to the
+ * individual table scan.
*
* @see CoreRules#MULTI_JOIN_OPTIMIZE_BUSHY
*/
@@ -130,9 +126,17 @@ public MultiJoinOptimizeBushyRule(RelFactories.JoinFactory joinFactory,
}
assert x == multiJoin.getNumTotalFields();
+ final List remainingConditions = new ArrayList<>();
final List unusedEdges = new ArrayList<>();
for (RexNode node : multiJoin.getJoinFilters()) {
- unusedEdges.add(multiJoin.createEdge(node));
+ LoptMultiJoin.Edge edge = multiJoin.createEdge(node);
+ if (edge.factors.cardinality() == 2) {
+ unusedEdges.add(edge);
+ } else {
+ // Conditions touching 1 or 3+ factors cannot be used as binary join
+ // edges. Re-apply them as a filter above the finished join tree.
+ remainingConditions.add(node);
+ }
}
// Comparator that chooses the best edge. A "good edge" is one that has
@@ -170,11 +174,8 @@ private double rowCountDiff(LoptMultiJoin.Edge edge) {
} else {
final LoptMultiJoin.Edge bestEdge = unusedEdges.get(edgeOrdinal);
- // For now, assume that the edge is between precisely two factors.
- // 1-factor conditions have probably been pushed down,
- // and 3-or-more-factor conditions are advanced. (TODO:)
- // Therefore, for now, the factors that are merged are exactly the
- // factors on this edge.
+ // Each edge in unusedEdges touches exactly two factors; conditions
+ // touching 1 or 3+ factors were separated out before the greedy loop.
assert bestEdge.factors.cardinality() == 2;
factors = bestEdge.factors.toArray();
}
@@ -299,8 +300,17 @@ private double rowCountDiff(LoptMultiJoin.Edge edge) {
}
final Pair top = Util.last(relNodes);
- relBuilder.push(top.left)
- .project(relBuilder.fields(top.right));
+ relBuilder.push(top.left);
+ if (!remainingConditions.isEmpty()) {
+ final RexVisitor shuttle =
+ new RexPermuteInputsShuttle(top.right, top.left);
+ final List remapped = new ArrayList<>();
+ for (RexNode c : remainingConditions) {
+ remapped.add(c.accept(shuttle));
+ }
+ relBuilder.filter(remapped);
+ }
+ relBuilder.project(relBuilder.fields(top.right));
call.transformTo(relBuilder.build());
}
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 0f5959fb574..715d9f4f7b8 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -4344,6 +4344,28 @@ private void checkPushJoinThroughUnionOnRightDoesNotMatchSemiOrAntiJoin(JoinRelT
.checkUnchanged();
}
+ /** Test case for the TODO in {@link MultiJoinOptimizeBushyRule}:
+ * "Join conditions that touch 3 factors."
+ *
+ * The CASE condition references three factors (e1, d, e2) and therefore
+ * cannot be represented as a binary join edge. The rule should handle it
+ * gracefully rather than throwing an {@code AssertionError}. */
+ @Test void testMultiJoinOptimizeBushyThreeFactorCondition() {
+ HepProgram preProgram = new HepProgramBuilder()
+ .addRuleInstance(CoreRules.FILTER_INTO_JOIN)
+ .addMatchOrder(HepMatchOrder.BOTTOM_UP)
+ .addRuleInstance(CoreRules.JOIN_TO_MULTI_JOIN)
+ .build();
+ HepProgram program = new HepProgramBuilder()
+ .addMatchOrder(HepMatchOrder.BOTTOM_UP)
+ .addRuleInstance(CoreRules.MULTI_JOIN_OPTIMIZE_BUSHY)
+ .build();
+ final String sql = "select e1.ename from emp e1, dept d, emp e2\n"
+ + "where e1.deptno = d.deptno and e2.deptno = d.deptno\n"
+ + "and d.deptno = case when e1.sal > 1000 then e2.empno else e1.empno end";
+ sql(sql).withPre(preProgram).withProgram(program).check();
+ }
+
@Test void testConvertMultiJoinRule() {
final String sql = "select e1.ename from emp e1, dept d, emp e2\n"
+ "where e1.deptno = d.deptno and d.deptno = e2.deptno";
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index fc60b886497..c782bb913cc 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -10775,6 +10775,34 @@ LogicalAggregate(group=[{0, 1}])
LogicalFilter(condition=[AND(=($0, 12), <>($1, 5))])
LogicalProject(MGR=[$3], COMM=[$6])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+
+
+
+
+ 1000 then e2.empno else e1.empno end]]>
+
+
+ ($5, 1000), $11, $0)), =($7, $9))], isFullOuterJoin=[false], joinTypes=[[INNER, INNER, INNER]], outerJoinConditions=[[NULL, NULL, NULL]], projFields=[[ALL, ALL, ALL]])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+
+
+ ($16, 1000), $2, $11))])
+ LogicalJoin(condition=[=($18, $0)], joinType=[inner])
+ LogicalJoin(condition=[=($9, $0)], joinType=[inner])
+ LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>