@@ -6,14 +6,14 @@ use std::{
66
77use datafusion:: {
88 catalog:: TableProvider ,
9- common:: { tree_node:: Transformed , Column } ,
9+ common:: { tree_node:: Transformed , Column , DFSchema } ,
1010 datasource:: { empty:: EmptyTable , DefaultTableSource } ,
1111 error:: DataFusionError ,
1212 sql:: TableReference ,
1313} ;
1414use datafusion_expr:: {
15- build_join_schema, Aggregate , Expr , Filter , Join , JoinConstraint , JoinType , LogicalPlan ,
16- Projection , SubqueryAlias , TableScan , Union ,
15+ build_join_schema, expr :: Alias , Aggregate , Expr , Filter , Join , JoinConstraint , JoinType ,
16+ LogicalPlan , Projection , SubqueryAlias , TableScan , Union ,
1717} ;
1818use iceberg_rust:: error:: Error ;
1919
@@ -82,9 +82,7 @@ pub(crate) fn delta_transform_down(
8282 LogicalPlan :: Join ( join) => {
8383 let inputs = transform_join ( join) ?. 0 ;
8484
85- Ok ( Transformed :: yes ( LogicalPlan :: Union (
86- Union :: try_new_by_name ( inputs) ?,
87- ) ) )
85+ Ok ( Transformed :: yes ( union_with_schema ( inputs, & join. schema ) ?) )
8886 }
8987 LogicalPlan :: Union ( union) => {
9088 let inputs = union
@@ -190,9 +188,10 @@ pub(crate) fn delta_transform_down(
190188 } ) ) ;
191189
192190 let inputs = vec ! [ aggregate_projection, anti_join] ;
193- Ok ( Transformed :: yes ( LogicalPlan :: Union (
194- Union :: try_new_by_name ( inputs) ?,
195- ) ) )
191+ Ok ( Transformed :: yes ( union_with_schema (
192+ inputs,
193+ & aggregate. schema ,
194+ ) ?) )
196195 }
197196 LogicalPlan :: TableScan ( scan) => {
198197 let mut scan = scan. clone ( ) ;
@@ -279,9 +278,7 @@ pub(crate) fn delta_transform_down(
279278 Arc :: new( left_delta) ,
280279 Arc :: new( right_delta) ,
281280 ] ;
282- Ok ( Transformed :: yes ( LogicalPlan :: Union (
283- Union :: try_new_by_name ( inputs) ?,
284- ) ) )
281+ Ok ( Transformed :: yes ( union_with_schema ( inputs, & join. schema ) ?) )
285282 }
286283 LogicalPlan :: Union ( union) => {
287284 let inputs = union
@@ -513,6 +510,27 @@ fn storage_table_group_expressions(
513510 . collect ( )
514511}
515512
513+ fn union_with_schema (
514+ inputs : Vec < Arc < LogicalPlan > > ,
515+ schema : & DFSchema ,
516+ ) -> Result < LogicalPlan , DataFusionError > {
517+ let union = Union :: try_new_by_name ( inputs) ?;
518+ let exprs = schema
519+ . iter ( )
520+ . map ( |( reference, field) | {
521+ Expr :: Alias ( Alias :: new (
522+ Expr :: Column ( Column :: new ( None :: < String > , field. name ( ) ) ) ,
523+ reference. cloned ( ) ,
524+ field. name ( ) ,
525+ ) )
526+ } )
527+ . collect :: < Vec < _ > > ( ) ;
528+ Ok ( LogicalPlan :: Projection ( Projection :: try_new (
529+ exprs,
530+ Arc :: new ( LogicalPlan :: Union ( union) ) ,
531+ ) ?) )
532+ }
533+
516534#[ cfg( test) ]
517535mod tests {
518536 use core:: panic;
@@ -812,73 +830,89 @@ mod tests {
812830 dbg ! ( & output) ;
813831
814832 if let LogicalPlan :: Projection ( proj) = output {
815- if let LogicalPlan :: Union ( union) = proj. input . deref ( ) {
816- if let LogicalPlan :: Join ( join) = union. inputs [ 0 ] . deref ( ) {
817- if let LogicalPlan :: Extension ( ext) = join. left . deref ( ) {
818- if let Some ( ext) = ext. node . as_any ( ) . downcast_ref :: < ForkNode > ( ) {
819- if let LogicalPlan :: TableScan ( table) = ext. input . deref ( ) {
820- assert_eq ! ( table. table_name. table( ) , "users" )
833+ if let LogicalPlan :: Projection ( proj) = proj. input . deref ( ) {
834+ if let LogicalPlan :: Union ( union) = proj. input . deref ( ) {
835+ if let LogicalPlan :: Projection ( proj) = union. inputs [ 0 ] . deref ( ) {
836+ if let LogicalPlan :: Join ( join) = proj. input . deref ( ) {
837+ if let LogicalPlan :: Extension ( ext) = join. left . deref ( ) {
838+ if let Some ( ext) = ext. node . as_any ( ) . downcast_ref :: < ForkNode > ( ) {
839+ if let LogicalPlan :: TableScan ( table) = ext. input . deref ( ) {
840+ assert_eq ! ( table. table_name. table( ) , "users" )
841+ } else {
842+ panic ! ( "Node is not a table scan." )
843+ }
844+ } else {
845+ panic ! ( "Node is not a ForkNode" )
846+ }
821847 } else {
822- panic ! ( "Node is not a table scan. " )
848+ panic ! ( "Node is not an extension " )
823849 }
824- } else {
825- panic ! ( "Node is not a ForkNode" )
826- }
827- } else {
828- panic ! ( "Node is not an extension" )
829- }
830- if let LogicalPlan :: Extension ( ext ) = join . right . deref ( ) {
831- if let Some ( ext ) = ext . node . as_any ( ) . downcast_ref :: < ForkNode > ( ) {
832- if let LogicalPlan :: TableScan ( table ) = ext . input . deref ( ) {
833- assert_eq ! ( table . table_name . table ( ) , "homes" )
850+ if let LogicalPlan :: Extension ( ext ) = join . right . deref ( ) {
851+ if let Some ( ext ) = ext . node . as_any ( ) . downcast_ref :: < ForkNode > ( ) {
852+ if let LogicalPlan :: TableScan ( table ) = ext . input . deref ( ) {
853+ assert_eq ! ( table . table_name . table ( ) , "homes" )
854+ } else {
855+ panic ! ( "Node is not a table scan." )
856+ }
857+ } else {
858+ panic ! ( "Node is not a ForkNode" )
859+ }
834860 } else {
835- panic ! ( "Node is not a table scan. " )
861+ panic ! ( "Node is not an extension " )
836862 }
837863 } else {
838- panic ! ( "Node is not a ForkNode" )
839- }
840- } else {
841- panic ! ( "Node is not an extension" )
842- }
843- } else {
844- panic ! ( "Node is not a CrossJoin." )
845- }
846- if let LogicalPlan :: Join ( join) = union. inputs [ 1 ] . deref ( ) {
847- if let LogicalPlan :: TableScan ( table) = join. left . deref ( ) {
848- assert_eq ! ( table. table_name. table( ) , "users" )
849- } else {
850- panic ! ( "Node is not a table scan." )
851- }
852- if let LogicalPlan :: Extension ( ext) = join. right . deref ( ) {
853- if ext. node . as_any ( ) . downcast_ref :: < ForkNode > ( ) . is_some ( ) {
854- } else {
855- panic ! ( "Node is not a ForkNode" )
864+ panic ! ( "Node is not a CrossJoin." )
856865 }
857- } else {
858- panic ! ( "Node is not an extension" )
859- }
860- } else {
861- panic ! ( "Node is not a CrossJoin." )
862- }
863- if let LogicalPlan :: Join ( join) = union. inputs [ 2 ] . deref ( ) {
864- if let LogicalPlan :: Extension ( ext) = join. left . deref ( ) {
865- if ext. node . as_any ( ) . downcast_ref :: < ForkNode > ( ) . is_some ( ) {
866+ if let LogicalPlan :: Projection ( proj) = union. inputs [ 1 ] . deref ( ) {
867+ if let LogicalPlan :: Join ( join) = proj. input . deref ( ) {
868+ if let LogicalPlan :: TableScan ( table) = join. left . deref ( ) {
869+ assert_eq ! ( table. table_name. table( ) , "users" )
870+ } else {
871+ panic ! ( "Node is not a table scan." )
872+ }
873+ if let LogicalPlan :: Extension ( ext) = join. right . deref ( ) {
874+ if ext. node . as_any ( ) . downcast_ref :: < ForkNode > ( ) . is_some ( ) {
875+ } else {
876+ panic ! ( "Node is not a ForkNode" )
877+ }
878+ } else {
879+ panic ! ( "Node is not an extension" )
880+ }
881+ } else {
882+ panic ! ( "Node is not a CrossJoin." )
883+ }
884+ if let LogicalPlan :: Projection ( proj) = union. inputs [ 2 ] . deref ( ) {
885+ if let LogicalPlan :: Join ( join) = proj. input . deref ( ) {
886+ if let LogicalPlan :: Extension ( ext) = join. left . deref ( ) {
887+ if ext. node . as_any ( ) . downcast_ref :: < ForkNode > ( ) . is_some ( ) {
888+ } else {
889+ panic ! ( "Node is not a RecveiverNode" )
890+ }
891+ } else {
892+ panic ! ( "Node is not an extension" )
893+ }
894+ if let LogicalPlan :: TableScan ( table) = join. right . deref ( ) {
895+ assert_eq ! ( table. table_name. table( ) , "homes" )
896+ } else {
897+ panic ! ( "Node is not a table scan." )
898+ }
899+ } else {
900+ panic ! ( "Node is not a CrossJoin." )
901+ }
902+ } else {
903+ panic ! ( "Node is not a projection." )
904+ }
866905 } else {
867- panic ! ( "Node is not a RecveiverNode " )
906+ panic ! ( "Node is not a projection. " )
868907 }
869908 } else {
870- panic ! ( "Node is not an extension" )
871- }
872- if let LogicalPlan :: TableScan ( table) = join. right . deref ( ) {
873- assert_eq ! ( table. table_name. table( ) , "homes" )
874- } else {
875- panic ! ( "Node is not a table scan." )
909+ panic ! ( "Node is not a projection." )
876910 }
877911 } else {
878- panic ! ( "Node is not a CrossJoin ." )
912+ panic ! ( "Node is not a filter ." )
879913 }
880914 } else {
881- panic ! ( "Node is not a filter ." )
915+ panic ! ( "Node is not a projection ." )
882916 }
883917 } else {
884918 panic ! ( "Node is not a projection." )
@@ -1057,67 +1091,84 @@ mod tests {
10571091 dbg ! ( & output) ;
10581092
10591093 if let LogicalPlan :: Projection ( proj) = output {
1060- if let LogicalPlan :: Union ( union) = proj. input . deref ( ) {
1061- if let LogicalPlan :: Projection ( proj) = union. inputs [ 0 ] . deref ( ) {
1062- if let LogicalPlan :: Join ( join) = proj. input . deref ( ) {
1063- if let LogicalPlan :: Extension ( ext) = join. left . deref ( ) {
1064- if let Some ( ext) = ext. node . as_any ( ) . downcast_ref :: < ForkNode > ( ) {
1065- if let LogicalPlan :: Aggregate ( aggregate) = ext. input . deref ( ) {
1066- if let LogicalPlan :: TableScan ( table) = aggregate. input . deref ( ) {
1067- assert_eq ! ( table. table_name. table( ) , "users" )
1094+ if let LogicalPlan :: Projection ( proj) = proj. input . deref ( ) {
1095+ if let LogicalPlan :: Union ( union) = proj. input . deref ( ) {
1096+ if let LogicalPlan :: Projection ( proj) = union. inputs [ 0 ] . deref ( ) {
1097+ if let LogicalPlan :: Projection ( proj) = proj. input . deref ( ) {
1098+ if let LogicalPlan :: Join ( join) = proj. input . deref ( ) {
1099+ if let LogicalPlan :: Extension ( ext) = join. left . deref ( ) {
1100+ if let Some ( ext) = ext. node . as_any ( ) . downcast_ref :: < ForkNode > ( )
1101+ {
1102+ if let LogicalPlan :: Aggregate ( aggregate) = ext. input . deref ( )
1103+ {
1104+ if let LogicalPlan :: TableScan ( table) =
1105+ aggregate. input . deref ( )
1106+ {
1107+ assert_eq ! ( table. table_name. table( ) , "users" )
1108+ } else {
1109+ panic ! ( "Node is not a table scan." )
1110+ }
1111+ } else {
1112+ panic ! ( "Node is not an aggregate." )
1113+ }
10681114 } else {
1069- panic ! ( "Node is not a table scan." )
1115+ panic ! ( "Node is not a ForkNode" )
1116+ }
1117+ } else {
1118+ panic ! ( "Node is not an extension" )
1119+ }
1120+ if let LogicalPlan :: Extension ( ext) = join. right . deref ( ) {
1121+ if let Some ( ext) = ext. node . as_any ( ) . downcast_ref :: < ForkNode > ( )
1122+ {
1123+ if let LogicalPlan :: TableScan ( table) = ext. input . deref ( ) {
1124+ assert_eq ! ( table. table_name. table( ) , "users_view" )
1125+ } else {
1126+ panic ! ( "Node is not a table scan." )
1127+ }
1128+ } else {
1129+ panic ! ( "Node is not a ForkNode" )
10701130 }
10711131 } else {
1072- panic ! ( "Node is not an aggregate. " )
1132+ panic ! ( "Node is not an extension " )
10731133 }
10741134 } else {
1075- panic ! ( "Node is not a ForkNode " )
1135+ panic ! ( "Node is not a CrossJoin. " )
10761136 }
10771137 } else {
1078- panic ! ( "Node is not an extension " )
1138+ panic ! ( "Node is not a projection. " )
10791139 }
1080- if let LogicalPlan :: Extension ( ext) = join. right . deref ( ) {
1081- if let Some ( ext) = ext. node . as_any ( ) . downcast_ref :: < ForkNode > ( ) {
1082- if let LogicalPlan :: TableScan ( table) = ext. input . deref ( ) {
1083- assert_eq ! ( table. table_name. table( ) , "users_view" )
1140+ } else {
1141+ panic ! ( "Node is not a projection." )
1142+ }
1143+ if let LogicalPlan :: Projection ( proj) = union. inputs [ 1 ] . deref ( ) {
1144+ if let LogicalPlan :: Join ( join) = proj. input . deref ( ) {
1145+ if let LogicalPlan :: Extension ( ext) = join. left . deref ( ) {
1146+ if ext. node . as_any ( ) . downcast_ref :: < ForkNode > ( ) . is_some ( ) {
10841147 } else {
1085- panic ! ( "Node is not a table scan. " )
1148+ panic ! ( "Node is not a ForkNode " )
10861149 }
10871150 } else {
1088- panic ! ( "Node is not a ForkNode" )
1151+ panic ! ( "Node is not an extension" )
1152+ }
1153+ if let LogicalPlan :: Extension ( ext) = join. right . deref ( ) {
1154+ if ext. node . as_any ( ) . downcast_ref :: < ForkNode > ( ) . is_some ( ) {
1155+ } else {
1156+ panic ! ( "Node is not a ForkNode" )
1157+ }
1158+ } else {
1159+ panic ! ( "Node is not an extension" )
10891160 }
10901161 } else {
1091- panic ! ( "Node is not an extension " )
1162+ panic ! ( "Node is not a CrossJoin. " )
10921163 }
10931164 } else {
1094- panic ! ( "Node is not a CrossJoin ." )
1165+ panic ! ( "Node is not a filter ." )
10951166 }
10961167 } else {
10971168 panic ! ( "Node is not a projection." )
10981169 }
1099- if let LogicalPlan :: Join ( join) = union. inputs [ 1 ] . deref ( ) {
1100- if let LogicalPlan :: Extension ( ext) = join. left . deref ( ) {
1101- if ext. node . as_any ( ) . downcast_ref :: < ForkNode > ( ) . is_some ( ) {
1102- } else {
1103- panic ! ( "Node is not a ForkNode" )
1104- }
1105- } else {
1106- panic ! ( "Node is not an extension" )
1107- }
1108- if let LogicalPlan :: Extension ( ext) = join. right . deref ( ) {
1109- if ext. node . as_any ( ) . downcast_ref :: < ForkNode > ( ) . is_some ( ) {
1110- } else {
1111- panic ! ( "Node is not a ForkNode" )
1112- }
1113- } else {
1114- panic ! ( "Node is not an extension" )
1115- }
1116- } else {
1117- panic ! ( "Node is not a CrossJoin." )
1118- }
11191170 } else {
1120- panic ! ( "Node is not a filter ." )
1171+ panic ! ( "Node is not a projection ." )
11211172 }
11221173 } else {
11231174 panic ! ( "Node is not a projection." )
0 commit comments