Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 45 additions & 18 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,18 @@ pub fn plan_to_sql(plan: &LogicalPlan) -> Result<ast::Statement> {
unparser.plan_to_sql(plan)
}

/// Context in which a child plan may need its own SQL SELECT scope.
///
/// Both variants are plain tags with no payload, so the type derives `Copy`
/// (a context value can be handed to several checks without being moved away)
/// together with `Debug`/`PartialEq`/`Eq` for diagnostics and comparisons.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SelectScopeContext {
    /// The direct input of a Window plan. Some inputs must remain below the
    /// Window node because their SELECT clauses are evaluated after window
    /// expressions in one SELECT block.
    WindowInput,
    /// The direct child of a SubqueryAlias. Some children must be emitted as a
    /// derived table so clauses owned by the alias scope are not flattened into
    /// the parent SELECT.
    SubqueryAliasChild,
}

impl Unparser<'_> {
pub fn plan_to_sql(&self, plan: &LogicalPlan) -> Result<ast::Statement> {
let mut plan = normalize_union_schema(plan)?;
Expand Down Expand Up @@ -542,16 +554,9 @@ impl Unparser<'_> {
}

/// Returns true when `plan`, used as the direct input of a Window node, must
/// be kept below that node by emitting it as a derived table.
///
/// The plan kinds named in this span are Projection, Distinct, Limit, Sort and
/// Union; the call in the body forwards `plan` together with
/// `SelectScopeContext::WindowInput` to
/// `plan_requires_independent_select_scope`.
// NOTE(review): this listing appears to interleave the removed inline
// `matches!` arm list with the added delegation call (diff residue without
// +/- markers) — confirm against the actual file before relying on it.
fn window_input_requires_derived_subquery(plan: &LogicalPlan) -> bool {
// These operators either produce a SELECT list or apply SQL clauses
// that are evaluated after window functions in a single SELECT block.
// Keep them below the Window node by emitting a derived table.
matches!(
Self::plan_requires_independent_select_scope(
plan,
LogicalPlan::Projection(_)
| LogicalPlan::Distinct(_)
| LogicalPlan::Limit(_)
| LogicalPlan::Sort(_)
| LogicalPlan::Union(_)
SelectScopeContext::WindowInput,
)
}

Expand Down Expand Up @@ -1896,17 +1901,39 @@ impl Unparser<'_> {

/// Returns true if a plan, when used as the direct child of a SubqueryAlias,
/// must be emitted as a derived subquery `(SELECT ...) AS alias`.
///
/// Plans like Aggregate or Window build their own SELECT clauses (GROUP BY,
/// window functions).
///
/// Sort, Limit and Union are listed in this span alongside Aggregate and
/// Window; the call in the body forwards `plan` together with
/// `SelectScopeContext::SubqueryAliasChild` to
/// `plan_requires_independent_select_scope`.
// NOTE(review): this listing appears to interleave the removed inline
// `matches!` arm list with the added delegation call (diff residue without
// +/- markers) — confirm against the actual file before relying on it.
fn requires_derived_subquery(plan: &LogicalPlan) -> bool {
matches!(
Self::plan_requires_independent_select_scope(
plan,
LogicalPlan::Aggregate(_)
| LogicalPlan::Window(_)
| LogicalPlan::Sort(_)
| LogicalPlan::Limit(_)
| LogicalPlan::Union(_)
SelectScopeContext::SubqueryAliasChild,
)
}

/// Decides whether `plan` has to be unparsed inside its own SELECT scope when
/// it sits in the position described by `context`.
///
/// * `SelectScopeContext::WindowInput`: true for Projection, Distinct, Limit,
///   Sort and Union, so clauses evaluated after window expressions stay in the
///   child scope.
/// * `SelectScopeContext::SubqueryAliasChild`: true for Aggregate, Window,
///   Sort, Limit and Union, so clauses owned by the aliased derived table are
///   not merged into the parent SELECT.
///
/// Every other plan kind yields `false`.
fn plan_requires_independent_select_scope(
    plan: &LogicalPlan,
    context: SelectScopeContext,
) -> bool {
    // Dispatch on the context first so each boundary lists its own plan kinds.
    match context {
        SelectScopeContext::WindowInput => matches!(
            plan,
            LogicalPlan::Projection(_)
                | LogicalPlan::Distinct(_)
                | LogicalPlan::Limit(_)
                | LogicalPlan::Sort(_)
                | LogicalPlan::Union(_)
        ),
        SelectScopeContext::SubqueryAliasChild => matches!(
            plan,
            LogicalPlan::Aggregate(_)
                | LogicalPlan::Window(_)
                | LogicalPlan::Sort(_)
                | LogicalPlan::Limit(_)
                | LogicalPlan::Union(_)
        ),
    }
}

Expand Down
137 changes: 136 additions & 1 deletion datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use datafusion_expr::test::function_stub::{
};
use datafusion_expr::{
ColumnarValue, EmptyRelation, Expr, Extension, LogicalPlan, LogicalPlanBuilder,
ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Union,
ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, SortExpr, Union,
UserDefinedLogicalNode, UserDefinedLogicalNodeCore, Volatility, WindowFrame,
WindowFunctionDefinition, cast, col, exists, in_subquery, lit, scalar_subquery,
table_scan, wildcard,
Expand Down Expand Up @@ -70,6 +70,22 @@ use datafusion_sql::unparser::extension_unparser::{
use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect};
use sqlparser::parser::Parser;

/// Test helper: builds a `row_number()` window expression named `alias`.
///
/// The window takes no arguments and no PARTITION BY, orders by the single
/// `order_by` sort expression, and uses `WindowFrame::new(None)` as its frame.
fn row_number_over(order_by: SortExpr, alias: &str) -> Expr {
    let fun = WindowFunctionDefinition::WindowUDF(row_number_udwf());
    let params = WindowFunctionParams {
        args: Vec::new(),
        partition_by: Vec::new(),
        order_by: vec![order_by],
        window_frame: WindowFrame::new(None),
        null_treatment: None,
        distinct: false,
        filter: None,
    };
    let window_fn = WindowFunction { fun, params };

    Expr::WindowFunction(Box::new(window_fn)).alias(alias)
}

#[test]
fn test_roundtrip_expr_1() {
let expr = roundtrip_expr(TableReference::bare("person"), "age > 35").unwrap();
Expand Down Expand Up @@ -1526,6 +1542,75 @@ fn test_table_scan_alias() -> Result<()> {
Ok(())
}

/// Unparses Aggregate, Window, Sort, Limit and Union plans that sit directly
/// under a SubqueryAlias `a` and pins the SQL produced for each of them.
#[test]
fn test_unparse_subquery_alias_select_scope_boundaries() -> Result<()> {
    let schema = Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("age", DataType::Int32, false),
    ]);
    // Each single-table case starts from a fresh scan of `t1`.
    let scan_t1 = || table_scan(Some("t1"), &schema, None);

    // Aggregate under the alias.
    let aliased_aggregate = scan_t1()?
        .aggregate(vec![col("id")], vec![sum(col("age")).alias("total_age")])?
        .alias("a")?
        .build()?;
    assert_snapshot!(
        plan_to_sql(&aliased_aggregate)?,
        @"SELECT * FROM (SELECT sum(t1.age) AS total_age, t1.id FROM t1 GROUP BY t1.id) AS a"
    );

    // Window under the alias.
    let row_idx = row_number_over(col("age").sort(true, true), "row_idx");
    let aliased_window = scan_t1()?.window(vec![row_idx])?.alias("a")?.build()?;
    assert_snapshot!(
        plan_to_sql(&aliased_window)?,
        @"SELECT * FROM (SELECT *, row_number() OVER (ORDER BY t1.age ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS row_idx FROM t1) AS a"
    );

    // Sort under the alias.
    let aliased_sort = scan_t1()?
        .sort(vec![col("age").sort(false, false)])?
        .alias("a")?
        .build()?;
    assert_snapshot!(
        plan_to_sql(&aliased_sort)?,
        @"SELECT * FROM (SELECT * FROM t1 ORDER BY t1.age DESC NULLS LAST) AS a"
    );

    // Limit under the alias.
    let aliased_limit = scan_t1()?.limit(0, Some(5))?.alias("a")?.build()?;
    assert_snapshot!(
        plan_to_sql(&aliased_limit)?,
        @"SELECT * FROM (SELECT * FROM t1 LIMIT 5) AS a"
    );

    // Union under the alias: two one-row projections sharing an `id` schema.
    let id_field = Field::new("id", DataType::Int32, false);
    let id_schema = Arc::new(DFSchema::try_from(Schema::new(vec![id_field]))?);
    let one_row = LogicalPlan::EmptyRelation(EmptyRelation {
        produce_one_row: true,
        schema: id_schema.clone(),
    });
    let union_plan = LogicalPlan::Union(Union {
        inputs: vec![
            project(one_row.clone(), vec![lit(1).alias("id")])?.into(),
            project(one_row, vec![lit(2).alias("id")])?.into(),
        ],
        schema: id_schema,
    });
    let aliased_union = LogicalPlanBuilder::from(union_plan).alias("a")?.build()?;
    assert_snapshot!(
        plan_to_sql(&aliased_union)?,
        @"SELECT * FROM (SELECT 1 AS id UNION ALL SELECT 2 AS id) AS a"
    );

    Ok(())
}

#[test]
fn test_table_scan_pushdown() -> Result<()> {
let schema = Schema::new(vec![
Expand Down Expand Up @@ -3128,6 +3213,56 @@ fn test_unparse_window_over_projection_without_projection() -> Result<()> {
Ok(())
}

/// Pins the SQL produced for a Window node whose direct input is a Sort, with
/// no Projection between them.
#[test]
fn test_unparse_window_over_sort_without_projection() -> Result<()> {
    let fields = vec![
        Field::new("k", DataType::Int32, false),
        Field::new("v", DataType::Int32, false),
    ];
    let schema = Schema::new(fields);

    let row_idx = row_number_over(col("v").sort(true, true), "row_idx");
    // Sort first, then stack the window directly on top of the sorted scan.
    let sorted_scan = table_scan(Some("test"), &schema, None)?
        .sort(vec![col("v").sort(false, false)])?;
    let plan = sorted_scan.window(vec![row_idx])?.build()?;

    assert_snapshot!(
        Unparser::default().plan_to_sql(&plan)?,
        @"SELECT *, row_number() OVER (ORDER BY derived_window_input.v ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS row_idx FROM (SELECT * FROM test ORDER BY test.v DESC NULLS LAST) AS derived_window_input"
    );

    Ok(())
}

/// Pins the SQL produced for a Window node whose direct input is a Union of
/// two one-row projections, with no Projection between Window and Union.
#[test]
fn test_unparse_window_over_union_without_projection() -> Result<()> {
    let kv_fields = vec![
        Field::new("k", DataType::Int32, false),
        Field::new("v", DataType::Int32, false),
    ];
    let kv_schema = Arc::new(DFSchema::try_from(Schema::new(kv_fields))?);

    // Both union branches project literals over the same one-row relation.
    let one_row = LogicalPlan::EmptyRelation(EmptyRelation {
        produce_one_row: true,
        schema: kv_schema.clone(),
    });
    let union_input = LogicalPlan::Union(Union {
        inputs: vec![
            project(one_row.clone(), vec![lit(1).alias("k"), lit(10).alias("v")])?.into(),
            project(one_row, vec![lit(2).alias("k"), lit(20).alias("v")])?.into(),
        ],
        schema: kv_schema,
    });

    let row_idx = row_number_over(col("v").sort(true, true), "row_idx");
    let plan = LogicalPlanBuilder::from(union_input)
        .window(vec![row_idx])?
        .build()?;

    assert_snapshot!(
        Unparser::default().plan_to_sql(&plan)?,
        @"SELECT *, row_number() OVER (ORDER BY derived_window_input.v ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS row_idx FROM (SELECT 1 AS k, 10 AS v UNION ALL SELECT 2 AS k, 20 AS v) AS derived_window_input"
    );

    Ok(())
}

#[test]
fn test_unparse_window_over_derived_aggregate_without_projection() -> Result<()> {
let schema = Schema::new(vec![
Expand Down
Loading