Skip to content

Commit 991040c

Browse files
committed
make projection after internal union
1 parent 020c652 commit 991040c

1 file changed

Lines changed: 158 additions & 107 deletions

File tree

  • datafusion_iceberg/src/materialized_view/delta_queries

datafusion_iceberg/src/materialized_view/delta_queries/transform.rs

Lines changed: 158 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ use std::{
66

77
use datafusion::{
88
catalog::TableProvider,
9-
common::{tree_node::Transformed, Column},
9+
common::{tree_node::Transformed, Column, DFSchema},
1010
datasource::{empty::EmptyTable, DefaultTableSource},
1111
error::DataFusionError,
1212
sql::TableReference,
1313
};
1414
use datafusion_expr::{
15-
build_join_schema, Aggregate, Expr, Filter, Join, JoinConstraint, JoinType, LogicalPlan,
16-
Projection, SubqueryAlias, TableScan, Union,
15+
build_join_schema, expr::Alias, Aggregate, Expr, Filter, Join, JoinConstraint, JoinType,
16+
LogicalPlan, Projection, SubqueryAlias, TableScan, Union,
1717
};
1818
use iceberg_rust::error::Error;
1919

@@ -82,9 +82,7 @@ pub(crate) fn delta_transform_down(
8282
LogicalPlan::Join(join) => {
8383
let inputs = transform_join(join)?.0;
8484

85-
Ok(Transformed::yes(LogicalPlan::Union(
86-
Union::try_new_by_name(inputs)?,
87-
)))
85+
Ok(Transformed::yes(union_with_schema(inputs, &join.schema)?))
8886
}
8987
LogicalPlan::Union(union) => {
9088
let inputs = union
@@ -190,9 +188,10 @@ pub(crate) fn delta_transform_down(
190188
}));
191189

192190
let inputs = vec![aggregate_projection, anti_join];
193-
Ok(Transformed::yes(LogicalPlan::Union(
194-
Union::try_new_by_name(inputs)?,
195-
)))
191+
Ok(Transformed::yes(union_with_schema(
192+
inputs,
193+
&aggregate.schema,
194+
)?))
196195
}
197196
LogicalPlan::TableScan(scan) => {
198197
let mut scan = scan.clone();
@@ -279,9 +278,7 @@ pub(crate) fn delta_transform_down(
279278
Arc::new(left_delta),
280279
Arc::new(right_delta),
281280
];
282-
Ok(Transformed::yes(LogicalPlan::Union(
283-
Union::try_new_by_name(inputs)?,
284-
)))
281+
Ok(Transformed::yes(union_with_schema(inputs, &join.schema)?))
285282
}
286283
LogicalPlan::Union(union) => {
287284
let inputs = union
@@ -513,6 +510,27 @@ fn storage_table_group_expressions(
513510
.collect()
514511
}
515512

513+
/// Unions `inputs` by column name and wraps the result in a projection whose
/// output columns carry the qualifiers and names of `schema`.
///
/// # Arguments
/// * `inputs` - the plans to combine; passed unchanged to `Union::try_new_by_name`.
/// * `schema` - the schema whose `(qualifier, field)` pairs define the projected
///   columns, in the order yielded by `schema.iter()`.
///
/// # Returns
/// A `LogicalPlan::Projection` whose input is the `LogicalPlan::Union`.
///
/// # Errors
/// Propagates the `DataFusionError` from `Union::try_new_by_name` or
/// `Projection::try_new`.
///
/// NOTE(review): presumably this exists because the by-name union does not keep
/// the table qualifiers of the original plan's schema — confirm against the
/// DataFusion version in use.
fn union_with_schema(
514+
inputs: Vec<Arc<LogicalPlan>>,
515+
schema: &DFSchema,
516+
) -> Result<LogicalPlan, DataFusionError> {
517+
let union = Union::try_new_by_name(inputs)?;
518+
// Build one aliased expression per field of the target schema.
let exprs = schema
519+
.iter()
520+
.map(|(reference, field)| {
521+
Expr::Alias(Alias::new(
522+
// The union output is referenced by bare field name, without a qualifier.
// NOTE(review): if `schema` holds two fields with the same name under
// different qualifiers, this lookup may be ambiguous — verify.
Expr::Column(Column::new(None::<String>, field.name())),
523+
// Re-attach the qualifier (if any) that `schema` records for this field.
reference.cloned(),
524+
field.name(),
525+
))
526+
})
527+
.collect::<Vec<_>>();
528+
Ok(LogicalPlan::Projection(Projection::try_new(
529+
exprs,
530+
Arc::new(LogicalPlan::Union(union)),
531+
)?))
532+
}
533+
516534
#[cfg(test)]
517535
mod tests {
518536
use core::panic;
@@ -812,73 +830,89 @@ mod tests {
812830
dbg!(&output);
813831

814832
if let LogicalPlan::Projection(proj) = output {
815-
if let LogicalPlan::Union(union) = proj.input.deref() {
816-
if let LogicalPlan::Join(join) = union.inputs[0].deref() {
817-
if let LogicalPlan::Extension(ext) = join.left.deref() {
818-
if let Some(ext) = ext.node.as_any().downcast_ref::<ForkNode>() {
819-
if let LogicalPlan::TableScan(table) = ext.input.deref() {
820-
assert_eq!(table.table_name.table(), "users")
833+
if let LogicalPlan::Projection(proj) = proj.input.deref() {
834+
if let LogicalPlan::Union(union) = proj.input.deref() {
835+
if let LogicalPlan::Projection(proj) = union.inputs[0].deref() {
836+
if let LogicalPlan::Join(join) = proj.input.deref() {
837+
if let LogicalPlan::Extension(ext) = join.left.deref() {
838+
if let Some(ext) = ext.node.as_any().downcast_ref::<ForkNode>() {
839+
if let LogicalPlan::TableScan(table) = ext.input.deref() {
840+
assert_eq!(table.table_name.table(), "users")
841+
} else {
842+
panic!("Node is not a table scan.")
843+
}
844+
} else {
845+
panic!("Node is not a ForkNode")
846+
}
821847
} else {
822-
panic!("Node is not a table scan.")
848+
panic!("Node is not an extension")
823849
}
824-
} else {
825-
panic!("Node is not a ForkNode")
826-
}
827-
} else {
828-
panic!("Node is not an extension")
829-
}
830-
if let LogicalPlan::Extension(ext) = join.right.deref() {
831-
if let Some(ext) = ext.node.as_any().downcast_ref::<ForkNode>() {
832-
if let LogicalPlan::TableScan(table) = ext.input.deref() {
833-
assert_eq!(table.table_name.table(), "homes")
850+
if let LogicalPlan::Extension(ext) = join.right.deref() {
851+
if let Some(ext) = ext.node.as_any().downcast_ref::<ForkNode>() {
852+
if let LogicalPlan::TableScan(table) = ext.input.deref() {
853+
assert_eq!(table.table_name.table(), "homes")
854+
} else {
855+
panic!("Node is not a table scan.")
856+
}
857+
} else {
858+
panic!("Node is not a ForkNode")
859+
}
834860
} else {
835-
panic!("Node is not a table scan.")
861+
panic!("Node is not an extension")
836862
}
837863
} else {
838-
panic!("Node is not a ForkNode")
839-
}
840-
} else {
841-
panic!("Node is not an extension")
842-
}
843-
} else {
844-
panic!("Node is not a CrossJoin.")
845-
}
846-
if let LogicalPlan::Join(join) = union.inputs[1].deref() {
847-
if let LogicalPlan::TableScan(table) = join.left.deref() {
848-
assert_eq!(table.table_name.table(), "users")
849-
} else {
850-
panic!("Node is not a table scan.")
851-
}
852-
if let LogicalPlan::Extension(ext) = join.right.deref() {
853-
if ext.node.as_any().downcast_ref::<ForkNode>().is_some() {
854-
} else {
855-
panic!("Node is not a ForkNode")
864+
panic!("Node is not a CrossJoin.")
856865
}
857-
} else {
858-
panic!("Node is not an extension")
859-
}
860-
} else {
861-
panic!("Node is not a CrossJoin.")
862-
}
863-
if let LogicalPlan::Join(join) = union.inputs[2].deref() {
864-
if let LogicalPlan::Extension(ext) = join.left.deref() {
865-
if ext.node.as_any().downcast_ref::<ForkNode>().is_some() {
866+
if let LogicalPlan::Projection(proj) = union.inputs[1].deref() {
867+
if let LogicalPlan::Join(join) = proj.input.deref() {
868+
if let LogicalPlan::TableScan(table) = join.left.deref() {
869+
assert_eq!(table.table_name.table(), "users")
870+
} else {
871+
panic!("Node is not a table scan.")
872+
}
873+
if let LogicalPlan::Extension(ext) = join.right.deref() {
874+
if ext.node.as_any().downcast_ref::<ForkNode>().is_some() {
875+
} else {
876+
panic!("Node is not a ForkNode")
877+
}
878+
} else {
879+
panic!("Node is not an extension")
880+
}
881+
} else {
882+
panic!("Node is not a CrossJoin.")
883+
}
884+
if let LogicalPlan::Projection(proj) = union.inputs[2].deref() {
885+
if let LogicalPlan::Join(join) = proj.input.deref() {
886+
if let LogicalPlan::Extension(ext) = join.left.deref() {
887+
if ext.node.as_any().downcast_ref::<ForkNode>().is_some() {
888+
} else {
889+
panic!("Node is not a RecveiverNode")
890+
}
891+
} else {
892+
panic!("Node is not an extension")
893+
}
894+
if let LogicalPlan::TableScan(table) = join.right.deref() {
895+
assert_eq!(table.table_name.table(), "homes")
896+
} else {
897+
panic!("Node is not a table scan.")
898+
}
899+
} else {
900+
panic!("Node is not a CrossJoin.")
901+
}
902+
} else {
903+
panic!("Node is not a projection.")
904+
}
866905
} else {
867-
panic!("Node is not a RecveiverNode")
906+
panic!("Node is not a projection.")
868907
}
869908
} else {
870-
panic!("Node is not an extension")
871-
}
872-
if let LogicalPlan::TableScan(table) = join.right.deref() {
873-
assert_eq!(table.table_name.table(), "homes")
874-
} else {
875-
panic!("Node is not a table scan.")
909+
panic!("Node is not a projection.")
876910
}
877911
} else {
878-
panic!("Node is not a CrossJoin.")
912+
panic!("Node is not a filter.")
879913
}
880914
} else {
881-
panic!("Node is not a filter.")
915+
panic!("Node is not a projection.")
882916
}
883917
} else {
884918
panic!("Node is not a projection.")
@@ -1057,67 +1091,84 @@ mod tests {
10571091
dbg!(&output);
10581092

10591093
if let LogicalPlan::Projection(proj) = output {
1060-
if let LogicalPlan::Union(union) = proj.input.deref() {
1061-
if let LogicalPlan::Projection(proj) = union.inputs[0].deref() {
1062-
if let LogicalPlan::Join(join) = proj.input.deref() {
1063-
if let LogicalPlan::Extension(ext) = join.left.deref() {
1064-
if let Some(ext) = ext.node.as_any().downcast_ref::<ForkNode>() {
1065-
if let LogicalPlan::Aggregate(aggregate) = ext.input.deref() {
1066-
if let LogicalPlan::TableScan(table) = aggregate.input.deref() {
1067-
assert_eq!(table.table_name.table(), "users")
1094+
if let LogicalPlan::Projection(proj) = proj.input.deref() {
1095+
if let LogicalPlan::Union(union) = proj.input.deref() {
1096+
if let LogicalPlan::Projection(proj) = union.inputs[0].deref() {
1097+
if let LogicalPlan::Projection(proj) = proj.input.deref() {
1098+
if let LogicalPlan::Join(join) = proj.input.deref() {
1099+
if let LogicalPlan::Extension(ext) = join.left.deref() {
1100+
if let Some(ext) = ext.node.as_any().downcast_ref::<ForkNode>()
1101+
{
1102+
if let LogicalPlan::Aggregate(aggregate) = ext.input.deref()
1103+
{
1104+
if let LogicalPlan::TableScan(table) =
1105+
aggregate.input.deref()
1106+
{
1107+
assert_eq!(table.table_name.table(), "users")
1108+
} else {
1109+
panic!("Node is not a table scan.")
1110+
}
1111+
} else {
1112+
panic!("Node is not an aggregate.")
1113+
}
10681114
} else {
1069-
panic!("Node is not a table scan.")
1115+
panic!("Node is not a ForkNode")
1116+
}
1117+
} else {
1118+
panic!("Node is not an extension")
1119+
}
1120+
if let LogicalPlan::Extension(ext) = join.right.deref() {
1121+
if let Some(ext) = ext.node.as_any().downcast_ref::<ForkNode>()
1122+
{
1123+
if let LogicalPlan::TableScan(table) = ext.input.deref() {
1124+
assert_eq!(table.table_name.table(), "users_view")
1125+
} else {
1126+
panic!("Node is not a table scan.")
1127+
}
1128+
} else {
1129+
panic!("Node is not a ForkNode")
10701130
}
10711131
} else {
1072-
panic!("Node is not an aggregate.")
1132+
panic!("Node is not an extension")
10731133
}
10741134
} else {
1075-
panic!("Node is not a ForkNode")
1135+
panic!("Node is not a CrossJoin.")
10761136
}
10771137
} else {
1078-
panic!("Node is not an extension")
1138+
panic!("Node is not a projection.")
10791139
}
1080-
if let LogicalPlan::Extension(ext) = join.right.deref() {
1081-
if let Some(ext) = ext.node.as_any().downcast_ref::<ForkNode>() {
1082-
if let LogicalPlan::TableScan(table) = ext.input.deref() {
1083-
assert_eq!(table.table_name.table(), "users_view")
1140+
} else {
1141+
panic!("Node is not a projection.")
1142+
}
1143+
if let LogicalPlan::Projection(proj) = union.inputs[1].deref() {
1144+
if let LogicalPlan::Join(join) = proj.input.deref() {
1145+
if let LogicalPlan::Extension(ext) = join.left.deref() {
1146+
if ext.node.as_any().downcast_ref::<ForkNode>().is_some() {
10841147
} else {
1085-
panic!("Node is not a table scan.")
1148+
panic!("Node is not a ForkNode")
10861149
}
10871150
} else {
1088-
panic!("Node is not a ForkNode")
1151+
panic!("Node is not an extension")
1152+
}
1153+
if let LogicalPlan::Extension(ext) = join.right.deref() {
1154+
if ext.node.as_any().downcast_ref::<ForkNode>().is_some() {
1155+
} else {
1156+
panic!("Node is not a ForkNode")
1157+
}
1158+
} else {
1159+
panic!("Node is not an extension")
10891160
}
10901161
} else {
1091-
panic!("Node is not an extension")
1162+
panic!("Node is not a CrossJoin.")
10921163
}
10931164
} else {
1094-
panic!("Node is not a CrossJoin.")
1165+
panic!("Node is not a filter.")
10951166
}
10961167
} else {
10971168
panic!("Node is not a projection.")
10981169
}
1099-
if let LogicalPlan::Join(join) = union.inputs[1].deref() {
1100-
if let LogicalPlan::Extension(ext) = join.left.deref() {
1101-
if ext.node.as_any().downcast_ref::<ForkNode>().is_some() {
1102-
} else {
1103-
panic!("Node is not a ForkNode")
1104-
}
1105-
} else {
1106-
panic!("Node is not an extension")
1107-
}
1108-
if let LogicalPlan::Extension(ext) = join.right.deref() {
1109-
if ext.node.as_any().downcast_ref::<ForkNode>().is_some() {
1110-
} else {
1111-
panic!("Node is not a ForkNode")
1112-
}
1113-
} else {
1114-
panic!("Node is not an extension")
1115-
}
1116-
} else {
1117-
panic!("Node is not a CrossJoin.")
1118-
}
11191170
} else {
1120-
panic!("Node is not a filter.")
1171+
panic!("Node is not a projection.")
11211172
}
11221173
} else {
11231174
panic!("Node is not a projection.")

0 commit comments

Comments
 (0)