@@ -409,6 +409,150 @@ class DataSourceV2EnhancedPartitionFilterSuite
409409 }
410410 }
411411
412+ test(" extract partition filter from translated OR with mixed partition and data references" ) {
413+ withTable(partFilterTableName) {
414+ sql(s " CREATE TABLE $partFilterTableName (part_col string, data string) USING $v2Source " +
415+ " PARTITIONED BY (part_col)" )
416+ sql(s " INSERT INTO $partFilterTableName VALUES " +
417+ " ('a', 'x'), ('a', 'other'), ('b', 'y'), ('c', 'z')" )
418+
419+ val df = sql(s " SELECT * FROM $partFilterTableName WHERE " +
420+ " (part_col = 'a' AND data = 'x') OR (part_col = 'b' AND data = 'y')" )
421+ checkAnswer(df, Seq (Row (" a" , " x" ), Row (" b" , " y" )))
422+ assertPushedPartitionPredicates(df, 1 )
423+ assertScanReturnsPartitionKeys(df, Set (" a" , " b" ))
424+ assertReferencedPartitionFieldOrdinals(df, Array (0 ), Array (" part_col" ))
425+ }
426+ }
427+
428+ test(" extract partition filter from untranslatable OR with mixed partition and data references" ) {
429+ withTable(partFilterTableName) {
430+ sql(s " CREATE TABLE $partFilterTableName (part_col string, data string) USING $v2Source " +
431+ " PARTITIONED BY (part_col)" )
432+ sql(s " INSERT INTO $partFilterTableName VALUES " +
433+ " ('a', 'x'), ('b', 'y'), ('c', 'z')" )
434+
435+ spark.udf.register(" my_upper_extract" , (s : String ) =>
436+ if (s == null ) null else s.toUpperCase(Locale .ROOT ))
437+
438+ val df = sql(s " SELECT * FROM $partFilterTableName WHERE " +
439+ " (my_upper_extract(part_col) = 'A' AND data = 'x') OR " +
440+ " (my_upper_extract(part_col) = 'B' AND data = 'y')" )
441+ checkAnswer(df, Seq (Row (" a" , " x" ), Row (" b" , " y" )))
442+ assertPushedPartitionPredicates(df, 1 )
443+ assertScanReturnsPartitionKeys(df, Set (" a" , " b" ))
444+ assertReferencedPartitionFieldOrdinals(df, Array (0 ), Array (" part_col" ))
445+ }
446+ }
447+
448+ test(" extract partition filter from OR with one partition-only and one mixed filter" ) {
449+ withTable(partFilterTableName) {
450+ sql(s " CREATE TABLE $partFilterTableName (part_col string, data string) USING $v2Source " +
451+ " PARTITIONED BY (part_col)" )
452+ sql(s " INSERT INTO $partFilterTableName VALUES " +
453+ " ('a', 'x'), ('a', 'other'), ('b', 'y'), ('b', 'other'), ('c', 'z')" )
454+
455+ val df = sql(s " SELECT * FROM $partFilterTableName WHERE " +
456+ " part_col = 'a' OR (part_col = 'b' AND data = 'y')" )
457+ checkAnswer(df, Seq (Row (" a" , " x" ), Row (" a" , " other" ), Row (" b" , " y" )))
458+ assertPushedPartitionPredicates(df, 1 )
459+ assertScanReturnsPartitionKeys(df, Set (" a" , " b" ))
460+ assertReferencedPartitionFieldOrdinals(df, Array (0 ), Array (" part_col" ))
461+ }
462+ }
463+
464+ test(" extract multi-column partition filter from OR with mixed partition and data references" ) {
465+ withTable(partFilterTableName) {
466+ sql(s " CREATE TABLE $partFilterTableName (p1 string, p2 string, data string) " +
467+ s " USING $v2Source PARTITIONED BY (p1, p2) " )
468+ sql(s " INSERT INTO $partFilterTableName VALUES " +
469+ " ('a', 'x', 'd1'), ('a', 'y', 'd2'), ('b', 'x', 'd3'), ('b', 'y', 'd4')" )
470+
471+ val df = sql(s " SELECT * FROM $partFilterTableName WHERE " +
472+ " (p1 = 'a' AND p2 = 'x' AND data = 'd1') OR (p1 = 'b' AND p2 = 'y' AND data = 'd4')" )
473+ checkAnswer(df, Seq (Row (" a" , " x" , " d1" ), Row (" b" , " y" , " d4" )))
474+ assertPushedPartitionPredicates(df, 1 )
475+ assertScanReturnsPartitionKeys(df, Set (" a/x" , " b/y" ))
476+ assertReferencedPartitionFieldOrdinals(df, Array (0 , 1 ), Array (" p1" , " p2" ))
477+ }
478+ }
479+
480+ test(" two partition predicates pushed: UDF on p1 and " +
481+ " extracted filter on p2 from mixed data and partition references" ) {
482+ withTable(partFilterTableName) {
483+ sql(s " CREATE TABLE $partFilterTableName (p1 string, p2 string, data string) " +
484+ s " USING $v2Source PARTITIONED BY (p1, p2) " )
485+ sql(s " INSERT INTO $partFilterTableName VALUES " +
486+ " ('a', 'x', 'd1'), " +
487+ " ('a', 'y', 'd4'), " +
488+ " ('b', 'x', 'd3'), " +
489+ " ('b', 'y', 'd4'), " +
490+ " ('c', 'z', 'd5')" )
491+
492+ spark.udf.register(" my_upper_multi" , (s : String ) =>
493+ if (s == null ) null else s.toUpperCase(Locale .ROOT ))
494+
495+ // my_upper_multi(p1) = 'A' is untranslatable and partition-only, so it is a partition filter.
496+ // The OR mixes p2 and data; we infer (p2 = 'x' OR p2 = 'y') as a partition filter.
497+ // Both are pushed as separate PartitionPredicates.
498+ val df = sql(s " SELECT * FROM $partFilterTableName WHERE " +
499+ " my_upper_multi(p1) = 'A' AND " +
500+ " ((p2 = 'x' AND data = 'd1') OR (p2 = 'y' AND data = 'd4'))" )
501+ checkAnswer(df, Seq (Row (" a" , " x" , " d1" ), Row (" a" , " y" , " d4" )))
502+ assertPushedPartitionPredicates(df, 2 )
503+ assertScanReturnsPartitionKeys(df, Set (" a/x" , " a/y" ))
504+ }
505+ }
506+
507+ test(" nested partition: extract partition filter from " +
508+ " OR with mixed data and partition references" ) {
509+ withTable(partFilterTableName) {
510+ sql(s " CREATE TABLE $partFilterTableName " +
511+ s " (s struct<tz: string, x: int>, data string) USING $v2Source " +
512+ " PARTITIONED BY (s.tz)" )
513+ sql(s " INSERT INTO $partFilterTableName VALUES " +
514+ " (named_struct('tz', 'LA', 'x', 1), 'a'), " +
515+ " (named_struct('tz', 'NY', 'x', 2), 'b'), " +
516+ " (named_struct('tz', 'SF', 'x', 3), 'c')" )
517+
518+ val df = sql(s " SELECT * FROM $partFilterTableName WHERE " +
519+ " (s.tz = 'LA' AND data = 'a') OR (s.tz = 'NY' AND data = 'b')" )
520+ checkAnswer(df, Seq (Row (Row (" LA" , 1 ), " a" ), Row (Row (" NY" , 2 ), " b" )))
521+ assertPushedPartitionPredicates(df, 1 )
522+ assertScanReturnsPartitionKeys(df, Set (" LA" , " NY" ))
523+ assertReferencedPartitionFieldOrdinals(df, Array (0 ), Array (" s.tz" ))
524+ }
525+ }
526+
527+ test(" nested partition: two partition predicates from " +
528+ " UDF and extracted mixed data and partition references" ) {
529+ withTable(partFilterTableName) {
530+ sql(s " CREATE TABLE $partFilterTableName " +
531+ s " (s struct<tz: string, x: int>, data string) USING $v2Source " +
532+ " PARTITIONED BY (s.tz)" )
533+ sql(s " INSERT INTO $partFilterTableName VALUES " +
534+ " (named_struct('tz', 'LA', 'x', 1), 'a'), " +
535+ " (named_struct('tz', 'la', 'x', 2), 'b'), " +
536+ " (named_struct('tz', 'NY', 'x', 3), 'c'), " +
537+ " (named_struct('tz', 'SF', 'x', 4), 'd')" )
538+
539+ spark.udf.register(" my_upper_nested2" , (s : String ) =>
540+ if (s == null ) null else s.toUpperCase(Locale .ROOT ))
541+
542+ // my_upper_nested2(s.tz) = 'LA' is untranslatable and partition-only,
543+ // it is a partition filter.
544+ // The OR mixes s.tz and data; we infer (s.tz = 'LA' OR s.tz = 'la') as an partition filter.
545+ // Both are pushed as separate PartitionPredicates.
546+ val df = sql(s " SELECT * FROM $partFilterTableName WHERE " +
547+ " my_upper_nested2(s.tz) = 'LA' AND " +
548+ " ((s.tz = 'LA' AND data = 'a') OR (s.tz = 'la' AND data = 'b'))" )
549+ checkAnswer(df, Seq (Row (Row (" LA" , 1 ), " a" ), Row (Row (" la" , 2 ), " b" )))
550+ assertPushedPartitionPredicates(df, 2 )
551+ assertScanReturnsPartitionKeys(df, Set (" LA" , " la" ))
552+ assertReferencedPartitionFieldOrdinals(df, Array (0 ), Array (" s.tz" ))
553+ }
554+ }
555+
412556 private def assertTranslatableBeforeUntranslatableInPostScan (df : DataFrame ): Unit = {
413557 val postScanFilterExec = df.queryExecution.executedPlan.collect {
414558 case f @ FilterExec (_, _) if f.exists(_.isInstanceOf [BatchScanExec ]) => f
0 commit comments