Skip to content

Commit d4b59dd

Browse files
szehon-ho authored and cloud-fan committed
[SPARK-56383][SQL][TEST] Add tests for partition filter extraction from mixed predicates
### What changes were proposed in this pull request?

Add tests to `DataSourceV2EnhancedPartitionFilterSuite` covering the case where `getPartitionFiltersAndDataFilters` (called from `PushDownUtils.pushPartitionPredicates`) extracts additional partition filters from predicates that reference both partition and data columns via `extractPredicatesWithinOutputSet`. This was one of the goals of SPARK-55596.

### Why are the changes needed?

Test coverage.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Ran tests.

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Claude (Cursor)

Closes #55248 from szehon-ho/partition_predicate_from_mixed.

Authored-by: Szehon Ho <szehon.apache@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 4db309d commit d4b59dd

1 file changed

Lines changed: 144 additions & 0 deletions

File tree

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2EnhancedPartitionFilterSuite.scala

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,150 @@ class DataSourceV2EnhancedPartitionFilterSuite
409409
}
410410
}
411411

412+
test("extract partition filter from translated OR with mixed partition and data references") {
413+
withTable(partFilterTableName) {
414+
sql(s"CREATE TABLE $partFilterTableName (part_col string, data string) USING $v2Source " +
415+
"PARTITIONED BY (part_col)")
416+
sql(s"INSERT INTO $partFilterTableName VALUES " +
417+
"('a', 'x'), ('a', 'other'), ('b', 'y'), ('c', 'z')")
418+
419+
val df = sql(s"SELECT * FROM $partFilterTableName WHERE " +
420+
"(part_col = 'a' AND data = 'x') OR (part_col = 'b' AND data = 'y')")
421+
checkAnswer(df, Seq(Row("a", "x"), Row("b", "y")))
422+
assertPushedPartitionPredicates(df, 1)
423+
assertScanReturnsPartitionKeys(df, Set("a", "b"))
424+
assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part_col"))
425+
}
426+
}
427+
428+
test("extract partition filter from untranslatable OR with mixed partition and data references") {
429+
withTable(partFilterTableName) {
430+
sql(s"CREATE TABLE $partFilterTableName (part_col string, data string) USING $v2Source " +
431+
"PARTITIONED BY (part_col)")
432+
sql(s"INSERT INTO $partFilterTableName VALUES " +
433+
"('a', 'x'), ('b', 'y'), ('c', 'z')")
434+
435+
spark.udf.register("my_upper_extract", (s: String) =>
436+
if (s == null) null else s.toUpperCase(Locale.ROOT))
437+
438+
val df = sql(s"SELECT * FROM $partFilterTableName WHERE " +
439+
"(my_upper_extract(part_col) = 'A' AND data = 'x') OR " +
440+
"(my_upper_extract(part_col) = 'B' AND data = 'y')")
441+
checkAnswer(df, Seq(Row("a", "x"), Row("b", "y")))
442+
assertPushedPartitionPredicates(df, 1)
443+
assertScanReturnsPartitionKeys(df, Set("a", "b"))
444+
assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part_col"))
445+
}
446+
}
447+
448+
test("extract partition filter from OR with one partition-only and one mixed filter") {
449+
withTable(partFilterTableName) {
450+
sql(s"CREATE TABLE $partFilterTableName (part_col string, data string) USING $v2Source " +
451+
"PARTITIONED BY (part_col)")
452+
sql(s"INSERT INTO $partFilterTableName VALUES " +
453+
"('a', 'x'), ('a', 'other'), ('b', 'y'), ('b', 'other'), ('c', 'z')")
454+
455+
val df = sql(s"SELECT * FROM $partFilterTableName WHERE " +
456+
"part_col = 'a' OR (part_col = 'b' AND data = 'y')")
457+
checkAnswer(df, Seq(Row("a", "x"), Row("a", "other"), Row("b", "y")))
458+
assertPushedPartitionPredicates(df, 1)
459+
assertScanReturnsPartitionKeys(df, Set("a", "b"))
460+
assertReferencedPartitionFieldOrdinals(df, Array(0), Array("part_col"))
461+
}
462+
}
463+
464+
test("extract multi-column partition filter from OR with mixed partition and data references") {
465+
withTable(partFilterTableName) {
466+
sql(s"CREATE TABLE $partFilterTableName (p1 string, p2 string, data string) " +
467+
s"USING $v2Source PARTITIONED BY (p1, p2)")
468+
sql(s"INSERT INTO $partFilterTableName VALUES " +
469+
"('a', 'x', 'd1'), ('a', 'y', 'd2'), ('b', 'x', 'd3'), ('b', 'y', 'd4')")
470+
471+
val df = sql(s"SELECT * FROM $partFilterTableName WHERE " +
472+
"(p1 = 'a' AND p2 = 'x' AND data = 'd1') OR (p1 = 'b' AND p2 = 'y' AND data = 'd4')")
473+
checkAnswer(df, Seq(Row("a", "x", "d1"), Row("b", "y", "d4")))
474+
assertPushedPartitionPredicates(df, 1)
475+
assertScanReturnsPartitionKeys(df, Set("a/x", "b/y"))
476+
assertReferencedPartitionFieldOrdinals(df, Array(0, 1), Array("p1", "p2"))
477+
}
478+
}
479+
480+
test("two partition predicates pushed: UDF on p1 and " +
481+
"extracted filter on p2 from mixed data and partition references") {
482+
withTable(partFilterTableName) {
483+
sql(s"CREATE TABLE $partFilterTableName (p1 string, p2 string, data string) " +
484+
s"USING $v2Source PARTITIONED BY (p1, p2)")
485+
sql(s"INSERT INTO $partFilterTableName VALUES " +
486+
"('a', 'x', 'd1'), " +
487+
"('a', 'y', 'd4'), " +
488+
"('b', 'x', 'd3'), " +
489+
"('b', 'y', 'd4'), " +
490+
"('c', 'z', 'd5')")
491+
492+
spark.udf.register("my_upper_multi", (s: String) =>
493+
if (s == null) null else s.toUpperCase(Locale.ROOT))
494+
495+
// my_upper_multi(p1) = 'A' is untranslatable and partition-only, so it is a partition filter.
496+
// The OR mixes p2 and data; we infer (p2 = 'x' OR p2 = 'y') as a partition filter.
497+
// Both are pushed as separate PartitionPredicates.
498+
val df = sql(s"SELECT * FROM $partFilterTableName WHERE " +
499+
"my_upper_multi(p1) = 'A' AND " +
500+
"((p2 = 'x' AND data = 'd1') OR (p2 = 'y' AND data = 'd4'))")
501+
checkAnswer(df, Seq(Row("a", "x", "d1"), Row("a", "y", "d4")))
502+
assertPushedPartitionPredicates(df, 2)
503+
assertScanReturnsPartitionKeys(df, Set("a/x", "a/y"))
504+
}
505+
}
506+
507+
test("nested partition: extract partition filter from " +
508+
"OR with mixed data and partition references") {
509+
withTable(partFilterTableName) {
510+
sql(s"CREATE TABLE $partFilterTableName " +
511+
s"(s struct<tz: string, x: int>, data string) USING $v2Source " +
512+
"PARTITIONED BY (s.tz)")
513+
sql(s"INSERT INTO $partFilterTableName VALUES " +
514+
"(named_struct('tz', 'LA', 'x', 1), 'a'), " +
515+
"(named_struct('tz', 'NY', 'x', 2), 'b'), " +
516+
"(named_struct('tz', 'SF', 'x', 3), 'c')")
517+
518+
val df = sql(s"SELECT * FROM $partFilterTableName WHERE " +
519+
"(s.tz = 'LA' AND data = 'a') OR (s.tz = 'NY' AND data = 'b')")
520+
checkAnswer(df, Seq(Row(Row("LA", 1), "a"), Row(Row("NY", 2), "b")))
521+
assertPushedPartitionPredicates(df, 1)
522+
assertScanReturnsPartitionKeys(df, Set("LA", "NY"))
523+
assertReferencedPartitionFieldOrdinals(df, Array(0), Array("s.tz"))
524+
}
525+
}
526+
527+
test("nested partition: two partition predicates from " +
528+
"UDF and extracted mixed data and partition references") {
529+
withTable(partFilterTableName) {
530+
sql(s"CREATE TABLE $partFilterTableName " +
531+
s"(s struct<tz: string, x: int>, data string) USING $v2Source " +
532+
"PARTITIONED BY (s.tz)")
533+
sql(s"INSERT INTO $partFilterTableName VALUES " +
534+
"(named_struct('tz', 'LA', 'x', 1), 'a'), " +
535+
"(named_struct('tz', 'la', 'x', 2), 'b'), " +
536+
"(named_struct('tz', 'NY', 'x', 3), 'c'), " +
537+
"(named_struct('tz', 'SF', 'x', 4), 'd')")
538+
539+
spark.udf.register("my_upper_nested2", (s: String) =>
540+
if (s == null) null else s.toUpperCase(Locale.ROOT))
541+
542+
// my_upper_nested2(s.tz) = 'LA' is untranslatable and partition-only,
543+
// it is a partition filter.
544+
// The OR mixes s.tz and data; we infer (s.tz = 'LA' OR s.tz = 'la') as an partition filter.
545+
// Both are pushed as separate PartitionPredicates.
546+
val df = sql(s"SELECT * FROM $partFilterTableName WHERE " +
547+
"my_upper_nested2(s.tz) = 'LA' AND " +
548+
"((s.tz = 'LA' AND data = 'a') OR (s.tz = 'la' AND data = 'b'))")
549+
checkAnswer(df, Seq(Row(Row("LA", 1), "a"), Row(Row("la", 2), "b")))
550+
assertPushedPartitionPredicates(df, 2)
551+
assertScanReturnsPartitionKeys(df, Set("LA", "la"))
552+
assertReferencedPartitionFieldOrdinals(df, Array(0), Array("s.tz"))
553+
}
554+
}
555+
412556
private def assertTranslatableBeforeUntranslatableInPostScan(df: DataFrame): Unit = {
413557
val postScanFilterExec = df.queryExecution.executedPlan.collect {
414558
case f @ FilterExec(_, _) if f.exists(_.isInstanceOf[BatchScanExec]) => f

0 commit comments

Comments
 (0)